use sage_core::agent::{ExecutionOutcome, UnifiedExecutor};
use sage_core::config::Config;
use sage_core::error::{SageError, SageResult};
use sage_core::output::{CostInfo, OutputEvent, OutputFormat, OutputWriter};
use sage_core::types::TaskMetadata;
use std::io::stdout;
use std::path::PathBuf;
use super::args::UnifiedArgs;
pub async fn execute_stream_json(
args: UnifiedArgs,
mut executor: UnifiedExecutor,
config: Config,
working_dir: PathBuf,
) -> SageResult<()> {
let mut writer = OutputWriter::new(stdout(), OutputFormat::StreamJson);
writer
.write_event(&OutputEvent::system("Sage Agent starting"))
.ok();
let task_description = match args.task {
Some(task) => {
if let Ok(task_path) = std::path::Path::new(&task).canonicalize() {
if task_path.is_file() {
writer
.write_event(&OutputEvent::system(format!(
"Loading task from file: {}",
task_path.display()
)))
.ok();
tokio::fs::read_to_string(&task_path)
.await
.map_err(|e| SageError::config(format!("Failed to read task file: {e}")))?
} else {
task
}
} else {
task
}
}
None => {
writer
.write_event(&OutputEvent::error("No task provided for stream mode"))
.ok();
return Err(SageError::config(
"Stream JSON mode requires a task. Use: sage --stream-json \"your task\"",
));
}
};
writer
.write_event(&OutputEvent::system(format!(
"Task: {}",
&task_description[..task_description.len().min(100)]
)))
.ok();
let session_recorder = if config.trajectory.is_enabled() {
let recorder = sage_core::trajectory::init_session_recorder(&working_dir);
if let Some(ref r) = recorder {
executor.set_session_recorder(r.clone());
}
recorder
} else {
None
};
let task = TaskMetadata::new(&task_description, &working_dir.display().to_string());
let start_time = std::time::Instant::now();
let outcome = executor.execute(task).await;
let duration = start_time.elapsed();
let session_id = if let Some(recorder) = &session_recorder {
Some(recorder.lock().await.session_id().to_string())
} else {
None
};
match outcome {
Ok(ref execution_outcome) => {
let execution = execution_outcome.execution();
let mut cost = CostInfo::new(
usize::try_from(execution.total_usage.input_tokens).unwrap_or(usize::MAX),
usize::try_from(execution.total_usage.output_tokens).unwrap_or(usize::MAX),
);
if let Some(cache_read) = execution.total_usage.cache_read_tokens {
cost = cost.with_cache_read(usize::try_from(cache_read).unwrap_or(usize::MAX));
}
if let Some(cache_write) = execution.total_usage.cache_write_tokens {
cost = cost.with_cache_creation(usize::try_from(cache_write).unwrap_or(usize::MAX));
}
let result_content = match execution_outcome {
ExecutionOutcome::Success(_) => execution
.final_result
.clone()
.unwrap_or_else(|| "Task completed successfully".to_string()),
ExecutionOutcome::Failed { error, .. } => {
format!("Error: {}", error.message)
}
ExecutionOutcome::Interrupted { .. } => "Task interrupted by user".to_string(),
ExecutionOutcome::MaxStepsReached { .. } => {
"Task reached maximum steps".to_string()
}
ExecutionOutcome::UserCancelled { .. } => "Task cancelled by user".to_string(),
ExecutionOutcome::NeedsUserInput { last_response, .. } => {
format!("Waiting for input: {}", last_response)
}
};
let result_event = match OutputEvent::result(&result_content) {
OutputEvent::Result(mut e) => {
e.duration_ms = u64::try_from(duration.as_millis()).unwrap_or(u64::MAX);
e.cost = Some(cost);
if let Some(id) = session_id {
e.session_id = Some(id);
}
OutputEvent::Result(e)
}
other => other,
};
writer.write_event(&result_event).ok();
}
Err(ref e) => {
writer.write_event(&OutputEvent::error(e.to_string())).ok();
}
}
outcome.map(|_| ())
}