mod builtins;
mod types;
#[cfg(test)]
mod test;
pub use types::*;
use builtins::{CallCounters, HostContext};
use std::collections::BTreeMap;
use std::sync::{Arc, Mutex};
use std::time::Instant;
use rhai::{Dynamic, Engine, EvalAltResult, Scope};
use crate::error::{Result, TinyAgentsError};
use crate::harness::context::RunContext;
use crate::harness::events::EventSink;
use crate::harness::ids::{SessionId, new_session_id};
/// Per-cell scratch buffers shared between a `ReplSession` and its host builtins.
///
/// Every field is an `Arc<Mutex<_>>`, so cloning a `CellBuffers` (as
/// `rebuild_engine` does when it builds the `HostContext`) shares the same
/// underlying storage rather than copying it.
#[derive(Clone, Default)]
pub(super) struct CellBuffers {
    /// Text appended via `push_stdout_line` during the current cell.
    stdout: Arc<Mutex<String>>,
    /// Call records appended via `push_call` during the current cell.
    calls: Arc<Mutex<Vec<ReplCallRecord>>>,
    /// Answer recorded via `set_answer`, if any, for the current cell.
    answer: Arc<Mutex<Option<String>>>,
    /// Error recorded via `set_host_error`; when a cell fails, `eval_cell`
    /// returns this in preference to the mapped Rhai error.
    host_error: Arc<Mutex<Option<TinyAgentsError>>>,
    /// `Debug`-rendered variable values captured just before the cell runs.
    vars_snapshot: Arc<Mutex<BTreeMap<String, String>>>,
}
/// Variables that persist across REPL cells.
///
/// Wraps the Rhai [`Scope`] that every cell is evaluated against. It also keeps
/// the host-assigned value of each reserved name (see `reserved_names()`), so
/// those values can be reinstated after a cell runs.
pub struct ReplVariables {
    scope: Scope<'static>,
    /// Host-assigned value of every reserved name; reinstated after each cell.
    reserved_baseline: BTreeMap<String, Dynamic>,
}

impl ReplVariables {
    /// Creates a scope in which every reserved name is bound to `()`.
    fn seeded() -> Self {
        let mut scope = Scope::new();
        let mut reserved_baseline = BTreeMap::new();
        for name in reserved_names() {
            let value = Dynamic::UNIT;
            scope.push(name.to_string(), value.clone());
            reserved_baseline.insert(name.to_string(), value);
        }
        Self {
            scope,
            reserved_baseline,
        }
    }

    /// Binds a user variable so that later cells can read it.
    ///
    /// If a cell previously declared `name` with `const`, the value is pushed
    /// as a new mutable binding that shadows the constant instead of panicking.
    ///
    /// # Errors
    ///
    /// Returns [`TinyAgentsError::Capability`] if `name` is a reserved REPL name.
    pub fn set(&mut self, name: impl Into<String>, value: ReplValue) -> Result<()> {
        let name = name.into();
        if reserved_names().any(|r| name == r) {
            return Err(TinyAgentsError::Capability(format!(
                "`{name}` is a reserved REPL name and cannot be set as a variable"
            )));
        }
        Self::assign(&mut self.scope, name, repl_value_to_dynamic(&value));
        Ok(())
    }

    /// Returns the current value bound to `name`, or `None` if it is unbound.
    pub fn get(&self, name: &str) -> Option<ReplValue> {
        self.scope
            .get_value::<Dynamic>(name)
            .map(|d| dynamic_to_repl_value(&d))
    }

    /// Sets a reserved name and records the value as its baseline, so that
    /// `restore_reserved` reinstates it after every subsequent cell.
    fn set_reserved(&mut self, name: &str, value: Dynamic) {
        Self::assign(&mut self.scope, name.to_string(), value.clone());
        self.reserved_baseline.insert(name.to_string(), value);
    }

    /// Renders every scope entry with `Debug` formatting, keyed by name.
    ///
    /// When a name appears more than once (shadowing), later entries overwrite
    /// earlier ones in the returned map.
    fn snapshot(&self) -> BTreeMap<String, String> {
        let mut map = BTreeMap::new();
        for (name, _is_const, value) in self.scope.iter() {
            map.insert(name.to_string(), format!("{value:?}"));
        }
        map
    }

    /// Reinstates the baseline value of every reserved name, undoing any
    /// assignment or shadowing a cell performed on it.
    fn restore_reserved(&mut self) {
        // Disjoint field borrows: read the baseline while mutating the scope.
        for (name, value) in &self.reserved_baseline {
            Self::assign(&mut self.scope, name.clone(), value.clone());
        }
    }

    /// Writes `value` to the newest binding of `name`, creating it if absent.
    ///
    /// `Scope::set_value` panics when the newest binding is a constant. A
    /// script can create such a binding with `const name = ...;`. In that case
    /// a fresh mutable binding is pushed to shadow the constant, so a script
    /// cannot crash the host this way.
    fn assign(scope: &mut Scope<'static>, name: String, value: Dynamic) {
        if scope.is_constant(&name) == Some(true) {
            scope.push(name, value);
        } else {
            scope.set_value(name, value);
        }
    }
}

impl Default for ReplVariables {
    /// Equivalent to a freshly seeded scope (every reserved name bound to `()`).
    fn default() -> Self {
        Self::seeded()
    }
}
/// A Rhai ("ragsh") REPL session whose variable scope persists between
/// evaluated cells (see `eval_cell`).
///
/// `State` is application state shared with the host builtins. `Ctx` is the
/// type parameter of the session's [`RunContext`].
///
/// NOTE(review): `policy` and `capabilities` are public, but only the
/// `with_policy` / `with_capabilities` builders rebuild `engine`. Assigning the
/// fields directly leaves the builtins with the old copies, while `eval_cell`
/// reads the new `policy` limits.
pub struct ReplSession<State = (), Ctx = ()>
where
    State: Send + Sync,
{
    /// Session identifier; its string form is passed to builtins as the session label.
    pub session_id: SessionId,
    /// Run configuration and context; its event sink is cloned into `events`.
    pub run_context: RunContext<Ctx>,
    /// Variables that persist across cells, including the reserved names.
    pub variables: ReplVariables,
    /// Registry and language handed to the host builtins.
    pub capabilities: ReplCapabilities<State>,
    /// Per-cell limits (e.g. `max_script_bytes`, `max_output_bytes`); also cloned into the builtins.
    pub policy: ReplPolicy,
    /// Event sink handed to the host builtins.
    pub events: EventSink,
    /// Application state shared with the host builtins.
    state: Arc<State>,
    /// Call counters shared with the host builtins; they are not reset between cells here.
    counters: Arc<Mutex<CallCounters>>,
    /// Graph blueprint drafts shared with the host builtins; presumably keyed by draft name — TODO confirm.
    drafts: Arc<Mutex<BTreeMap<String, GraphBlueprintHandle>>>,
    /// Rhai engine with the host builtins registered; rebuilt by `rebuild_engine`.
    engine: Engine,
    /// Per-cell output buffers shared with the host builtins.
    buffers: CellBuffers,
}
impl<State: Send + Sync + Default + 'static> ReplSession<State, ()> {
    /// Creates a session with default capabilities and policy.
    ///
    /// The session gets a fresh run context whose run id has the form
    /// `repl-run-<seq>`, and `()` as its context payload.
    pub fn new() -> Self {
        let capabilities = ReplCapabilities::default();
        let policy = ReplPolicy::default();
        let run_id = format!("repl-run-{}", crate::harness::ids::next_seq());
        let config = crate::harness::context::RunConfig::new(run_id);
        let run_context = RunContext::new(config, ());
        Self::from_parts(capabilities, policy, run_context)
    }
}

impl<State: Send + Sync + Default + 'static> Default for ReplSession<State, ()> {
    /// Same as [`ReplSession::new`].
    fn default() -> Self {
        Self::new()
    }
}
impl<State: Send + Sync + Default + 'static, Ctx> ReplSession<State, Ctx> {
    /// Assembles a session from explicit capabilities, policy, and run context.
    ///
    /// The session gets:
    /// - a fresh session id;
    /// - variables seeded with the reserved names;
    /// - `State::default()` as its application state;
    /// - an event sink cloned from `run_context.events`.
    ///
    /// The Rhai engine is built last, once every shared handle it captures exists.
    pub fn from_parts(
        capabilities: ReplCapabilities<State>,
        policy: ReplPolicy,
        run_context: RunContext<Ctx>,
    ) -> Self {
        let buffers = CellBuffers::default();
        let events = run_context.events.clone();
        let mut session = Self {
            session_id: new_session_id(),
            run_context,
            variables: ReplVariables::seeded(),
            capabilities,
            policy,
            events,
            state: Arc::new(State::default()),
            counters: Arc::new(Mutex::new(CallCounters::default())),
            drafts: Arc::new(Mutex::new(BTreeMap::new())),
            // Placeholder only: `rebuild_engine` replaces it just below. A raw
            // engine avoids the standard-library setup `Engine::new()` performs.
            engine: Engine::new_raw(),
            buffers,
        };
        session.rebuild_engine();
        session
    }
}
impl<State: Send + Sync + 'static, Ctx> ReplSession<State, Ctx> {
    /// Rebuilds the Rhai engine so the host builtins see the session's current
    /// capabilities, policy, state, and shared buffers.
    fn rebuild_engine(&mut self) {
        let host = HostContext {
            registry: self.capabilities.registry.clone(),
            state: self.state.clone(),
            policy: self.policy.clone(),
            language: self.capabilities.language.clone(),
            session_label: self.session_id.as_str().to_string(),
            run_depth: self.run_context.config.depth,
            events: self.events.clone(),
            buffers: self.buffers.clone(),
            counters: self.counters.clone(),
            drafts: self.drafts.clone(),
        };
        self.engine = builtins::build_engine(Arc::new(host));
    }

    /// Replaces the policy and rebuilds the engine so the builtins observe it.
    pub fn with_policy(mut self, policy: ReplPolicy) -> Self {
        self.policy = policy;
        self.rebuild_engine();
        self
    }

    /// Replaces the capabilities and rebuilds the engine so the builtins observe them.
    pub fn with_capabilities(mut self, capabilities: ReplCapabilities<State>) -> Self {
        self.capabilities = capabilities;
        self.rebuild_engine();
        self
    }

    /// Replaces the shared application state and rebuilds the engine so the
    /// builtins observe it.
    pub fn with_state(mut self, state: Arc<State>) -> Self {
        self.state = state;
        self.rebuild_engine();
        self
    }

    /// Returns a new handle to the shared application state.
    pub fn app_state(&self) -> Arc<State> {
        Arc::clone(&self.state)
    }

    /// Sets the reserved `context` script variable.
    ///
    /// The value is reinstated after every subsequent cell.
    pub fn set_context(&mut self, value: ReplValue) {
        let dynamic = repl_value_to_dynamic(&value);
        self.variables.set_reserved("context", dynamic);
    }

    /// Sets the reserved `state` script variable (distinct from the shared
    /// application state returned by `app_state`).
    ///
    /// The value is reinstated after every subsequent cell.
    pub fn set_state_var(&mut self, value: ReplValue) {
        let dynamic = repl_value_to_dynamic(&value);
        self.variables.set_reserved("state", dynamic);
    }

    /// Evaluates one script cell against the persistent variable scope.
    ///
    /// The cell runs in these steps:
    /// 1. Enforce `max_script_bytes`.
    /// 2. Clear the per-cell buffers.
    /// 3. Snapshot the variables and publish the snapshot to the builtins.
    /// 4. Run the script.
    /// 5. Reinstate every reserved name, whether or not the script succeeded.
    ///
    /// # Errors
    ///
    /// - `LimitExceeded` if the script exceeds `max_script_bytes`, or if
    ///   stdout plus the returned value exceeds `max_output_bytes`.
    /// - The error recorded by a host builtin, when the script failed and one was set.
    /// - Otherwise the Rhai error, translated by `map_rhai_error`.
    pub fn eval_cell(&mut self, script: &str) -> Result<ReplResult> {
        let start = Instant::now();

        let script_bytes = script.len();
        let max_script_bytes = self.policy.max_script_bytes;
        if script_bytes > max_script_bytes {
            return Err(TinyAgentsError::LimitExceeded(format!(
                "ragsh cell is {script_bytes} bytes, exceeding the max_script_bytes limit of {max_script_bytes}"
            )));
        }

        self.buffers.reset();
        let before = self.variables.snapshot();
        // Publish the pre-cell view so builtins can inspect it while the cell runs.
        *self
            .buffers
            .vars_snapshot
            .lock()
            .expect("vars_snapshot poisoned") = before.clone();

        let outcome = self
            .engine
            .eval_with_scope::<Dynamic>(&mut self.variables.scope, script);
        // Scripts may not keep their own values for reserved names.
        self.variables.restore_reserved();

        // A host-side error recorded by a builtin is more specific than the
        // Rhai error it caused, so it takes precedence.
        let value_dynamic = outcome.map_err(|err| {
            self.buffers
                .take_host_error()
                .unwrap_or_else(|| map_rhai_error(*err))
        })?;

        let value = (!value_dynamic.is_unit()).then(|| dynamic_to_repl_value(&value_dynamic));

        let stdout = self.buffers.stdout();
        let calls = self.buffers.take_calls();
        let final_answer = self.buffers.answer();

        let output_bytes = stdout.len() + value.as_ref().map_or(0, ReplValue::byte_len);
        let max_output_bytes = self.policy.max_output_bytes;
        if output_bytes > max_output_bytes {
            return Err(TinyAgentsError::LimitExceeded(format!(
                "ragsh cell produced {output_bytes} bytes of output, exceeding the max_output_bytes limit of {max_output_bytes}"
            )));
        }

        let variables_changed = diff_changed(&before, &self.variables.snapshot());
        Ok(ReplResult {
            stdout,
            value,
            variables_changed,
            calls,
            final_answer,
            elapsed: start.elapsed(),
        })
    }
}
impl CellBuffers {
    /// Clears stdout, call records, the answer, and any recorded host error.
    ///
    /// `vars_snapshot` is not touched here.
    fn reset(&self) {
        self.stdout.lock().expect("stdout poisoned").clear();
        self.calls.lock().expect("calls poisoned").clear();
        self.answer.lock().expect("answer poisoned").take();
        self.host_error.lock().expect("host_error poisoned").take();
    }

    /// Returns a copy of everything written to stdout so far.
    fn stdout(&self) -> String {
        let text = self.stdout.lock().expect("stdout poisoned");
        text.as_str().to_owned()
    }

    /// Moves the recorded calls out, leaving the buffer empty.
    fn take_calls(&self) -> Vec<ReplCallRecord> {
        let mut recorded = self.calls.lock().expect("calls poisoned");
        std::mem::take(&mut *recorded)
    }

    /// Returns a copy of the recorded answer, if one was set.
    fn answer(&self) -> Option<String> {
        let slot = self.answer.lock().expect("answer poisoned");
        slot.as_ref().cloned()
    }

    /// Removes and returns the recorded host error, if any.
    fn take_host_error(&self) -> Option<TinyAgentsError> {
        let mut slot = self.host_error.lock().expect("host_error poisoned");
        slot.take()
    }

    /// Appends a call record for the current cell.
    pub(super) fn push_call(&self, record: ReplCallRecord) {
        let mut recorded = self.calls.lock().expect("calls poisoned");
        recorded.push(record);
    }

    /// Appends `line` followed by a newline to stdout.
    pub(super) fn push_stdout_line(&self, line: &str) {
        let mut buffer = self.stdout.lock().expect("stdout poisoned");
        buffer.extend([line, "\n"]);
    }

    /// Records the answer for the current cell, replacing any earlier one.
    pub(super) fn set_answer(&self, content: String) {
        let mut slot = self.answer.lock().expect("answer poisoned");
        *slot = Some(content);
    }

    /// Records a host-side error, replacing any earlier one.
    pub(super) fn set_host_error(&self, err: TinyAgentsError) {
        let mut slot = self.host_error.lock().expect("host_error poisoned");
        *slot = Some(err);
    }

    /// Returns a copy of the pre-cell variable snapshot.
    pub(super) fn vars_snapshot(&self) -> BTreeMap<String, String> {
        let snapshot = self.vars_snapshot.lock().expect("vars_snapshot poisoned");
        snapshot.clone()
    }
}
fn map_rhai_error(err: EvalAltResult) -> TinyAgentsError {
match err {
EvalAltResult::ErrorTooManyOperations(pos) => TinyAgentsError::LimitExceeded(format!(
"ragsh cell exceeded the operation limit (max_operations) at {pos}"
)),
other => TinyAgentsError::Validation(format!("ragsh evaluation error: {other}")),
}
}
/// Lists the non-reserved variables that are new in `after`, or whose rendered
/// value differs from `before`.
///
/// The inputs are name-to-rendering snapshots as produced by
/// `ReplVariables::snapshot`. The result is already sorted and duplicate-free,
/// because it walks `after`, a `BTreeMap`, in ascending key order.
fn diff_changed(
    before: &BTreeMap<String, String>,
    after: &BTreeMap<String, String>,
) -> Vec<String> {
    let mut changed = Vec::new();
    for (name, rendered) in after {
        if reserved_names().any(|r| r == name.as_str()) {
            continue;
        }
        let is_new_or_modified = match before.get(name) {
            Some(previous) => previous != rendered,
            None => true,
        };
        if is_new_or_modified {
            changed.push(name.clone());
        }
    }
    changed
}
/// Converts a host-side [`ReplValue`] into a Rhai [`Dynamic`].
///
/// Arrays and maps are converted recursively; map keys become Rhai identifiers.
pub(super) fn repl_value_to_dynamic(value: &ReplValue) -> Dynamic {
    match value {
        ReplValue::Unit => Dynamic::UNIT,
        ReplValue::Bool(flag) => Dynamic::from_bool(*flag),
        ReplValue::Int(number) => Dynamic::from_int(*number),
        ReplValue::Float(number) => Dynamic::from_float(*number),
        ReplValue::String(text) => Dynamic::from(text.clone()),
        ReplValue::Array(items) => Dynamic::from_array(
            items
                .iter()
                .map(repl_value_to_dynamic)
                .collect::<rhai::Array>(),
        ),
        ReplValue::Map(entries) => Dynamic::from_map(
            entries
                .iter()
                .map(|(key, item)| (key.as_str().into(), repl_value_to_dynamic(item)))
                .collect::<rhai::Map>(),
        ),
    }
}
/// Converts a Rhai [`Dynamic`] into a host-side [`ReplValue`].
///
/// Unit, bool, integer, float, string, array, and object-map values map onto
/// the matching `ReplValue` variant, recursing into array elements and map
/// values. Any other type is rendered with `to_string()` and returned as
/// `ReplValue::String`.
pub(super) fn dynamic_to_repl_value(value: &Dynamic) -> ReplValue {
    if value.is_unit() {
        return ReplValue::Unit;
    }
    if value.is_bool() {
        return ReplValue::Bool(value.as_bool().unwrap_or(false));
    }
    if value.is_int() {
        return ReplValue::Int(value.as_int().unwrap_or(0));
    }
    if value.is_float() {
        return ReplValue::Float(value.as_float().unwrap_or(0.0));
    }
    if value.is_string() {
        return ReplValue::String(value.clone().into_string().unwrap_or_default());
    }
    // Borrow containers in place rather than cloning them. Cloning here would
    // deep-copy the whole subtree, and the recursion would then copy each
    // nested container again at every level.
    if value.is_array()
        && let Some(items) = value.read_lock::<rhai::Array>()
    {
        return ReplValue::Array(items.iter().map(dynamic_to_repl_value).collect());
    }
    if value.is_map()
        && let Some(map) = value.read_lock::<rhai::Map>()
    {
        let mut out = BTreeMap::new();
        for (k, v) in map.iter() {
            out.insert(k.to_string(), dynamic_to_repl_value(v));
        }
        return ReplValue::Map(out);
    }
    ReplValue::String(value.to_string())
}
pub(super) fn json_to_repl_value(value: &serde_json::Value) -> ReplValue {
match value {
serde_json::Value::Null => ReplValue::Unit,
serde_json::Value::Bool(b) => ReplValue::Bool(*b),
serde_json::Value::Number(n) => {
if let Some(i) = n.as_i64() {
ReplValue::Int(i)
} else {
ReplValue::Float(n.as_f64().unwrap_or(0.0))
}
}
serde_json::Value::String(s) => ReplValue::String(s.clone()),
serde_json::Value::Array(items) => {
ReplValue::Array(items.iter().map(json_to_repl_value).collect())
}
serde_json::Value::Object(map) => ReplValue::Map(
map.iter()
.map(|(k, v)| (k.clone(), json_to_repl_value(v)))
.collect(),
),
}
}