#[path = "hang_preventing_executor.rs"]
pub mod hang_preventing_executor;
use std::path::{Path, PathBuf};
use std::sync::Once;
use std::time::{Duration, Instant};
use futures::StreamExt;
use objectiveai_sdk::cli::command::binary::BinaryExecutor;
use objectiveai_sdk::cli::command::{CommandExecutor, CommandRequest, CommandResponse};
pub use hang_preventing_executor::HangPreventingBinaryCommandExecutor;
/// Mirrors the `UPDATE_SNAPSHOTS` toggle into `INSTA_UPDATE`, at most once per process.
///
/// `UPDATE_SNAPSHOTS=1` selects `INSTA_UPDATE=always`; any other value (or an unset /
/// non-unicode variable) selects `INSTA_UPDATE=no`. `INSTA_UPDATE` is presumably read by
/// the insta snapshot crate — confirm against the test setup.
fn sync_snapshots_env() {
    static SYNCED: Once = Once::new();
    SYNCED.call_once(|| {
        let wants_update = matches!(std::env::var("UPDATE_SNAPSHOTS").as_deref(), Ok("1"));
        let mode = if wants_update { "always" } else { "no" };
        // SAFETY: `Once` guarantees this write runs at most once per process.
        // NOTE(review): `set_var` is only sound while no other thread reads or writes the
        // environment; `Once` does not exclude concurrent readers in parallel tests.
        unsafe { std::env::set_var("INSTA_UPDATE", mode) };
    });
}
/// Returns the API address for the test server, built from `OBJECTIVEAI_TEST_PORT`.
///
/// Yields `Some("http://127.0.0.1:<port>")` when the variable is set to a non-blank value.
/// Surrounding whitespace is trimmed. Returns `None` when the variable is unset,
/// non-unicode, or blank. Callers then fall back to the CLI's default address instead of
/// receiving a malformed URL such as `http://127.0.0.1:`.
pub fn test_api_address() -> Option<String> {
    let raw = std::env::var("OBJECTIVEAI_TEST_PORT").ok()?;
    let port = raw.trim();
    if port.is_empty() {
        return None;
    }
    Some(format!("http://127.0.0.1:{port}"))
}
pub fn runtime_dir() -> PathBuf {
Path::new(env!("CARGO_MANIFEST_DIR")).join(".objectiveai-tests")
}
/// Path of the `objectiveai-cli` binary inside [`runtime_dir`].
///
/// On Windows the path gets an `.exe` extension; elsewhere it is used as-is.
pub fn cli_binary() -> PathBuf {
    let binary = runtime_dir().join("objectiveai-cli");
    if cfg!(windows) {
        binary.with_extension("exe")
    } else {
        binary
    }
}
/// Per-test working directory: [`runtime_dir`] joined with the current thread's name.
///
/// Ensures the snapshot env toggle is synced first, and logs the chosen directory to
/// stderr. The directory is not created here.
///
/// # Panics
/// Panics if the current thread has no name.
///
/// NOTE(review): the thread name is used verbatim as a path component. Names containing
/// `::` (tests inside modules) may be invalid directory names on Windows.
pub fn test_base_dir() -> PathBuf {
    sync_snapshots_env();
    let current = std::thread::current();
    let test_name = current.name().expect("test thread must have a name");
    let base = runtime_dir().join(test_name);
    eprintln!("test base dir: {}", base.display());
    base
}
/// Builds a CLI executor rooted at this test's [`test_base_dir`].
pub fn executor() -> HangPreventingBinaryCommandExecutor {
    let base_dir = test_base_dir();
    executor_with_base_dir(&base_dir)
}
/// Builds a hang-preventing CLI executor that uses `base_dir` as its config root.
///
/// The wrapped [`BinaryExecutor`] runs [`cli_binary`] with `CONFIG_BASE_DIR` set to
/// `base_dir` (lossily converted to UTF-8). When [`test_api_address`] yields an address,
/// it is also passed as `OBJECTIVEAI_ADDRESS`.
pub fn executor_with_base_dir(base_dir: &Path) -> HangPreventingBinaryCommandExecutor {
    let config_dir = base_dir.to_string_lossy().into_owned();
    let binary = BinaryExecutor::from_path(cli_binary()).env("CONFIG_BASE_DIR", config_dir);
    let binary = match test_api_address() {
        Some(address) => binary.env("OBJECTIVEAI_ADDRESS", address),
        None => binary,
    };
    HangPreventingBinaryCommandExecutor::new(binary, base_dir.to_path_buf())
}
/// Executes `request` as a streaming command and gathers every item into a `Vec`.
///
/// # Panics
/// Panics if starting the command fails, or on the first stream item that is an error.
pub async fn collect_stream<E, R, T>(executor: &E, request: R) -> Vec<T>
where
    E: CommandExecutor,
    E::Error: std::fmt::Debug,
    R: CommandRequest + Send,
    T: CommandResponse + serde::de::DeserializeOwned + Send + 'static,
{
    let stream = executor
        .execute::<R, T>(request, None)
        .await
        .expect("CommandExecutor::execute failed");
    let pinned = std::pin::pin!(stream);
    pinned
        .map(|item| item.expect("CommandExecutor stream item was Err"))
        .collect::<Vec<T>>()
        .await
}
/// Executes `request` and returns its single response.
///
/// # Panics
/// Panics if the executor reports an error.
pub async fn execute_one<E, R, T>(executor: &E, request: R) -> T
where
    E: CommandExecutor,
    E::Error: std::fmt::Debug,
    R: CommandRequest + Send,
    T: CommandResponse + serde::de::DeserializeOwned + Send + 'static,
{
    let outcome = executor.execute_one::<R, T>(request, None).await;
    outcome.expect("CommandExecutor::execute_one failed")
}
/// Runs `sql` through the CLI's `db query` command and returns the result rows.
///
/// The request uses a fixed 30-second timeout, no token limit and no `jq` filter.
/// Each row is a vector of JSON column values.
///
/// # Panics
/// Panics if the executor call fails.
pub async fn db_query<E>(executor: &E, sql: &str) -> Vec<Vec<serde_json::Value>>
where
    E: CommandExecutor,
    E::Error: std::fmt::Debug,
{
    use objectiveai_sdk::cli::command::db::query::{
        Path as DbPath, Request as DbReq, Response as DbResp,
    };
    let request = DbReq {
        path_type: DbPath::DbQuery,
        query: sql.to_owned(),
        timeout_seconds: 30,
        max_tokens: None,
        jq: None,
    };
    executor
        .execute_one::<DbReq, DbResp>(request, None)
        .await
        .expect("db query executor call")
        .rows
}
/// Escapes `s` for use inside a single-quoted SQL string literal by doubling each `'`.
///
/// No other characters are altered.
fn sql_escape(s: &str) -> String {
    let mut escaped = String::with_capacity(s.len());
    for ch in s.chars() {
        if ch == '\'' {
            escaped.push_str("''");
        } else {
            escaped.push(ch);
        }
    }
    escaped
}
pub async fn read_continuation<E>(executor: &E, aih: &str) -> Option<String>
where
E: CommandExecutor,
E::Error: std::fmt::Debug,
{
let sql = format!(
"SELECT continuation FROM agent_continuations \
WHERE agent_instance_hierarchy = '{}'",
sql_escape(aih),
);
let rows = db_query(executor, &sql).await;
rows.into_iter().next().and_then(|mut row| {
row.pop().and_then(|v| v.as_str().map(str::to_string))
})
}
/// Polls `agent_continuations` every 100 ms until a non-empty continuation appears for
/// `aih`, then returns it.
///
/// # Panics
/// Panics once `timeout` has elapsed without a non-empty continuation.
pub async fn wait_for_continuation<E>(executor: &E, aih: &str, timeout: Duration) -> String
where
    E: CommandExecutor,
    E::Error: std::fmt::Debug,
{
    let deadline = Instant::now() + timeout;
    loop {
        let found = read_continuation(executor, aih)
            .await
            .filter(|continuation| !continuation.is_empty());
        if let Some(continuation) = found {
            return continuation;
        }
        if Instant::now() >= deadline {
            panic!(
                "no agent_continuations row for {aih} after {:?}",
                timeout,
            );
        }
        tokio::time::sleep(Duration::from_millis(100)).await;
    }
}
pub async fn wait_for_request_continuation<E>(
executor: &E,
response_id: &str,
timeout: Duration,
) -> Option<String>
where
E: CommandExecutor,
E::Error: std::fmt::Debug,
{
let sql = format!(
"SELECT body->>'continuation' FROM logs.agent_completion_requests \
WHERE response_id = '{}'",
sql_escape(response_id),
);
let deadline = Instant::now() + timeout;
loop {
let rows = db_query(executor, &sql).await;
if let Some(mut row) = rows.into_iter().next() {
return row.pop().and_then(|v| v.as_str().map(str::to_string));
}
if Instant::now() >= deadline {
panic!(
"no logs.agent_completion_requests row for response_id={response_id} after {:?}",
timeout,
);
}
tokio::time::sleep(Duration::from_millis(100)).await;
}
}
/// Lists the tool-call function names recorded in `logs.agent_completion_responses` for
/// `response_id`.
///
/// The query casts each jsonb `function.name` to `text`. For a jsonb string, that yields
/// its JSON encoding: quoted, with `"`/`\`/control characters escaped. Each value is
/// therefore JSON-decoded back to the raw name. If a value is not a valid JSON string
/// literal, it falls back to stripping surrounding `"` characters, as before. Non-string
/// columns and empty names are skipped.
pub async fn tool_call_names_for_response<E>(executor: &E, response_id: &str) -> Vec<String>
where
    E: CommandExecutor,
    E::Error: std::fmt::Debug,
{
    let sql = format!(
        "SELECT jsonb_path_query(body, '$.messages[*].tool_calls[*].function.name')::text \
         FROM logs.agent_completion_responses WHERE response_id = '{}'",
        sql_escape(response_id),
    );
    let rows = db_query(executor, &sql).await;
    rows.into_iter()
        .filter_map(|mut row| row.pop())
        .filter_map(|v| match v {
            // Decode the JSON string literal so escapes (e.g. `\"`) are undone; a plain
            // trim of `"` would turn the name `a"` (rendered `"a\""`) into `a\`.
            serde_json::Value::String(raw) => Some(
                serde_json::from_str::<String>(&raw)
                    .unwrap_or_else(|_| raw.trim_matches('"').to_string()),
            ),
            _ => None,
        })
        .filter(|s| !s.is_empty())
        .collect()
}
/// Loads and parses the JSON snapshot `<dir>/<name>.json`.
///
/// # Panics
/// Panics, naming the snapshot path, if the file cannot be read or is not valid JSON.
pub fn load_snapshot(dir: &Path, name: &str) -> serde_json::Value {
    let path = dir.join(format!("{name}.json"));
    let content = std::fs::read_to_string(&path)
        .unwrap_or_else(|e| panic!("failed to read snapshot {}: {e}", path.display()));
    // Report which snapshot is malformed instead of a bare `unwrap()` on the parse error.
    serde_json::from_str(&content)
        .unwrap_or_else(|e| panic!("failed to parse snapshot {}: {e}", path.display()))
}
/// Asserts that `normalized`, serialized to JSON, matches the snapshot file at
/// `snapshot_path`.
///
/// Both sides are passed through [`rounded`] and then `normalize_agent_lineages` before
/// comparison. On a mismatch, pretty-printed `<snapshot_name>.actual.json` and
/// `<snapshot_name>.expected.json` files are written into [`test_base_dir`]. The function
/// then panics with their paths, a `diff -u` command line, and the first diverging lines.
///
/// # Panics
/// Panics on mismatch, or if the snapshot cannot be read/parsed, or if any I/O needed
/// to write the diff files fails.
pub fn assert_normalized_snapshot<T: serde::Serialize>(
    snapshot_path: &Path,
    snapshot_name: &str,
    normalized: &T,
) {
    let canonical = |v: &serde_json::Value| normalize_agent_lineages(&rounded(v));

    let expected_raw = std::fs::read_to_string(snapshot_path)
        .unwrap_or_else(|e| panic!("read snapshot {}: {e}", snapshot_path.display()));
    let expected_value: serde_json::Value = serde_json::from_str(&expected_raw)
        .unwrap_or_else(|e| panic!("parse snapshot {}: {e}", snapshot_path.display()));
    let expected_rounded = canonical(&expected_value);

    let actual_value = serde_json::to_value(normalized).expect("normalized value serialises");
    let actual_rounded = canonical(&actual_value);

    if actual_rounded != expected_rounded {
        let actual_pretty = serde_json::to_string_pretty(&actual_rounded)
            .expect("rounded Value serialises to pretty JSON");
        let expected_pretty = serde_json::to_string_pretty(&expected_rounded)
            .expect("rounded Value serialises to pretty JSON");

        let out_dir = test_base_dir();
        std::fs::create_dir_all(&out_dir)
            .unwrap_or_else(|e| panic!("create {} for snapshot diff: {e}", out_dir.display()));
        let actual_path = out_dir.join(format!("{snapshot_name}.actual.json"));
        let expected_path = out_dir.join(format!("{snapshot_name}.expected.json"));
        for (target, contents) in [(&actual_path, &actual_pretty), (&expected_path, &expected_pretty)] {
            std::fs::write(target, contents)
                .unwrap_or_else(|e| panic!("write {}: {e}", target.display()));
        }

        panic!(
            "snapshot mismatch for `{snapshot_name}`\n \
             source: {}\n \
             expected: {}\n \
             actual: {}\n \
             diff: diff -u {} {}\n\
             {}",
            snapshot_path.display(),
            expected_path.display(),
            actual_path.display(),
            expected_path.display(),
            actual_path.display(),
            first_diff_lines(&expected_pretty, &actual_pretty, 30),
        );
    }
}
/// Renders a positional, line-by-line comparison of `expected` and `actual`.
///
/// Lines are paired by index; each unequal pair is emitted as a `-`/`+` pair tagged with
/// its 1-based line number. A side that has run out is shown as `<EOF>`. Output stops
/// after `max_lines` differing pairs, with a note. If no pair differs, a hint about
/// formatting is appended instead.
fn first_diff_lines(expected: &str, actual: &str, max_lines: usize) -> String {
    let mut report = String::from(" first diverging lines:\n");
    let mut left = expected.lines();
    let mut right = actual.lines();
    let mut shown = 0usize;
    for line_no in 1usize.. {
        let (exp_line, act_line) = (left.next(), right.next());
        if exp_line.is_none() && act_line.is_none() {
            break;
        }
        if exp_line == act_line {
            continue;
        }
        report += &format!(" L{line_no:>4} - {}\n", exp_line.unwrap_or("<EOF>"));
        report += &format!(" L{line_no:>4} + {}\n", act_line.unwrap_or("<EOF>"));
        shown += 1;
        if shown >= max_lines {
            report += &format!(
                " … ({} max lines reached; run the diff command above for the full picture)\n",
                max_lines
            );
            break;
        }
    }
    if shown == 0 {
        report += " (no line-level differences — check pretty-print formatting)\n";
    }
    report
}
fn normalize_agent_lineages(value: &serde_json::Value) -> serde_json::Value {
fn normalize_agent_string(s: &str) -> String {
let without_prefix = match s.rsplit_once('/') {
Some((_, tail)) => tail,
None => s,
};
match without_prefix.rsplit_once('-') {
Some((head, _)) => format!("{head}-"),
None => without_prefix.to_string(),
}
}
match value {
serde_json::Value::Object(obj) => {
let mut out = serde_json::Map::with_capacity(obj.len());
for (k, v) in obj {
if matches!(k.as_str(), "agent_remote") {
continue;
}
let normalized_v = match k.as_str() {
"agent" => match v {
serde_json::Value::String(s) => {
serde_json::Value::String(normalize_agent_string(s))
}
_ => normalize_agent_lineages(v),
},
"agent_id" | "agent_full_id" => match v {
serde_json::Value::String(_) => {
serde_json::Value::String(String::new())
}
_ => normalize_agent_lineages(v),
},
_ => normalize_agent_lineages(v),
};
out.insert(k.clone(), normalized_v);
}
serde_json::Value::Object(out)
}
serde_json::Value::Array(arr) => serde_json::Value::Array(
arr.iter().map(normalize_agent_lineages).collect(),
),
_ => value.clone(),
}
}
pub fn rounded(value: &serde_json::Value) -> serde_json::Value {
match value {
serde_json::Value::Number(n) => {
if let Some(f) = n.as_f64() {
let s12 = format!("{:.12e}", f);
let f12: f64 = s12.parse().unwrap_or(f);
let s8 = format!("{:.8e}", f12);
let f8: f64 = s8.parse().unwrap_or(f12);
serde_json::Value::Number(
serde_json::Number::from_f64(f8).unwrap_or_else(|| n.clone()),
)
} else {
value.clone()
}
}
serde_json::Value::Array(arr) => {
serde_json::Value::Array(arr.iter().map(rounded).collect())
}
serde_json::Value::Object(obj) => {
serde_json::Value::Object(obj.iter().map(|(k, v)| (k.clone(), rounded(v))).collect())
}
_ => value.clone(),
}
}