use std::sync::Arc;
use async_trait::async_trait;
use serde_json::json;
use super::{Capability, CapabilityStatus};
use crate::atoms::PostToolExecHook;
use crate::tool_types::{ToolCall, ToolDefinition, ToolResult};
use crate::traits::{SessionFileSystem, ToolContext};
use crate::typed_id::SessionId;
// MAX_PERSISTED_STREAM_BYTES: cap (1 MiB) on the bytes kept from each stream before it is
// written; `truncate_for_persistence` appends a short marker on top of the kept bytes.
// OUTPUT_DIR: directory prefix under which `<id>.stdout` / `<id>.stderr` files are written.
const MAX_PERSISTED_STREAM_BYTES: usize = 1024 * 1024; const OUTPUT_DIR: &str = "/outputs";
/// Outcome of persisting a tool call's output streams to the session file store.
///
/// A path field is `Some` only when the corresponding stream was non-empty and
/// its write succeeded.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PersistResult {
    /// Path of the persisted stdout file, if one was written.
    pub stdout_path: Option<String>,
    /// Path of the persisted stderr file, if one was written.
    pub stderr_path: Option<String>,
    /// Number of lines in the original stdout, counted before any size-cap
    /// truncation; `0` when stdout was empty.
    pub stdout_total_lines: usize,
}
/// Writes the non-empty streams of a tool call to `{OUTPUT_DIR}/<id>.stdout` and
/// `{OUTPUT_DIR}/<id>.stderr` in the session file store.
///
/// `tool_call_id` is reduced to its alphanumeric, `-` and `_` characters before
/// it is used as the file stem, so path separators and dots never reach the path.
/// Each stream is capped with `truncate_for_persistence` before being written.
///
/// Returns `None` when both streams are empty, when the sanitized id is empty,
/// or when no write succeeded. Write failures are not reported individually:
/// the matching path in the result simply stays `None`.
pub async fn persist_output(
    file_store: &Arc<dyn SessionFileSystem>,
    session_id: SessionId,
    tool_call_id: &str,
    stdout: &str,
    stderr: &str,
) -> Option<PersistResult> {
    if stdout.is_empty() && stderr.is_empty() {
        return None;
    }

    // Build a file-name-safe stem from the tool call id.
    let mut safe_id = String::new();
    for ch in tool_call_id.chars() {
        if ch.is_alphanumeric() || matches!(ch, '-' | '_') {
            safe_id.push(ch);
        }
    }
    if safe_id.is_empty() {
        return None;
    }

    let mut stdout_path = None;
    let mut stdout_total_lines = 0;
    if !stdout.is_empty() {
        let body = truncate_for_persistence(stdout, MAX_PERSISTED_STREAM_BYTES);
        let target = format!("{OUTPUT_DIR}/{safe_id}.stdout");
        // Line count reflects the full stream, not the capped body.
        stdout_total_lines = stdout.lines().count();
        let written = file_store
            .write_file(session_id, &target, &body, "utf-8")
            .await
            .is_ok();
        if written {
            stdout_path = Some(target);
        }
    }

    let mut stderr_path = None;
    if !stderr.is_empty() {
        let body = truncate_for_persistence(stderr, MAX_PERSISTED_STREAM_BYTES);
        let target = format!("{OUTPUT_DIR}/{safe_id}.stderr");
        let written = file_store
            .write_file(session_id, &target, &body, "utf-8")
            .await
            .is_ok();
        if written {
            stderr_path = Some(target);
        }
    }

    // Nothing landed in the store: report that as "not persisted".
    if stdout_path.is_none() && stderr_path.is_none() {
        return None;
    }
    Some(PersistResult {
        stdout_path,
        stderr_path,
        stdout_total_lines,
    })
}
/// Persists a tool call's stdout/stderr by delegating to [`persist_output`].
///
/// This is a straight pass-through: no size threshold is applied here, so any
/// non-empty output is handed on for persistence. The return value is exactly
/// what [`persist_output`] returns.
pub async fn persist_large_output(
    file_store: &Arc<dyn SessionFileSystem>,
    session_id: SessionId,
    tool_call_id: &str,
    stdout: &str,
    stderr: &str,
) -> Option<PersistResult> {
    let pending = persist_output(file_store, session_id, tool_call_id, stdout, stderr);
    pending.await
}
/// Caps `text` at `max_bytes` bytes for storage.
///
/// Text that already fits is returned unchanged. Otherwise the text is cut at
/// the nearest UTF-8 character boundary at or below `max_bytes` and a
/// truncation marker is appended, so the returned string is longer than the
/// kept prefix by the length of that marker.
fn truncate_for_persistence(text: &str, max_bytes: usize) -> String {
    const MARKER: &str = "\n\n[output truncated before persistence due to size limit]";

    if text.len() <= max_bytes {
        return text.to_owned();
    }

    // Walk back from the limit until the cut does not split a character;
    // index 0 is always a boundary, so the search cannot fail.
    let cut = (0..=max_bytes)
        .rev()
        .find(|&idx| text.is_char_boundary(idx))
        .unwrap_or(0);

    let mut capped = String::with_capacity(cut + MARKER.len());
    capped.push_str(&text[..cut]);
    capped.push_str(MARKER);
    capped
}
/// Appends a note to `truncated` telling the reader where the full output was saved.
///
/// * `truncated` - the text to annotate; it is kept verbatim at the start of the result.
/// * `file_path` - path of the saved file; it is shown with a `/workspace` prefix.
/// * `full_size` - size of the full output in bytes.
///
/// The size is reported in KiB, rounded up, so a non-empty output smaller than
/// 1 KiB is shown as `1 KiB` rather than `0 KiB`. Exact multiples of 1024 are
/// reported unchanged.
pub fn annotate_truncated_output(truncated: &str, file_path: &str, full_size: usize) -> String {
    // Round up: flooring would under-report and print "0 KiB" for small outputs.
    let size_kb = full_size.div_ceil(1024);
    format!(
        "{truncated}\n\n[full output saved to /workspace{file_path} ({size_kb} KiB) — use read_file with offset/limit]"
    )
}
/// Capability that saves the output of opted-in tools to the session file store.
///
/// It carries no state; its `Capability` impl registers [`PersistOutputHook`]
/// as a post-tool-execution hook.
pub struct ToolOutputPersistenceCapability;
/// Metadata and hook registration for the tool-output persistence capability.
impl Capability for ToolOutputPersistenceCapability {
    /// Identifier string for this capability.
    fn id(&self) -> &str {
        const ID: &str = "tool_output_persistence";
        ID
    }

    /// Display name for this capability.
    fn name(&self) -> &str {
        const NAME: &str = "Tool Output Persistence";
        NAME
    }

    /// One-sentence summary of what the capability does.
    fn description(&self) -> &str {
        const DESCRIPTION: &str =
            "Persists full exec tool output to session VFS before truncation, \
             enabling lossless retrieval via read_file or grep.";
        DESCRIPTION
    }

    /// Always reports the capability as available.
    fn status(&self) -> CapabilityStatus {
        CapabilityStatus::Available
    }

    /// Declares a dependency on the `session_file_system` capability id.
    fn dependencies(&self) -> Vec<&'static str> {
        Vec::from(["session_file_system"])
    }

    /// Registers a single [`PersistOutputHook`] to run after tool execution.
    fn post_tool_exec_hooks(&self) -> Vec<Arc<dyn PostToolExecHook>> {
        let hook: Arc<dyn PostToolExecHook> = Arc::new(PersistOutputHook);
        vec![hook]
    }
}
/// Post-tool-execution hook that writes a tool's stdout/stderr to the session
/// file store and records the saved paths in the tool result. Carries no state.
pub struct PersistOutputHook;
#[async_trait]
impl PostToolExecHook for PersistOutputHook {
    /// Persists the tool's output and annotates the JSON result with where it went.
    ///
    /// The hook does nothing unless the tool's hints set `persist_output` to
    /// `Some(true)`, and it skips results that already contain an
    /// `output_files` entry. The text to persist is `result.raw_output` when
    /// present (it is taken, leaving `None` behind, even if a later step bails
    /// out); otherwise it is rebuilt from the JSON result's
    /// `stdout`/`stderr`/`output` fields.
    ///
    /// When persistence succeeds and the result is a JSON object:
    /// * `stdout` (if it is a string) gets a note pointing at the saved file,
    /// * `full_output` is set to the stdout path and `total_lines` to the
    ///   stdout line count,
    /// * `output_files` lists every saved path with a `/workspace` prefix.
    async fn after_exec(
        &self,
        tool_call: &ToolCall,
        tool_def: &ToolDefinition,
        result: &mut ToolResult,
        context: &ToolContext,
    ) {
        // Opt-in only.
        if tool_def.hints().persist_output != Some(true) {
            return;
        }

        // A result that already lists output files is left untouched.
        let already_listed = result
            .result
            .as_ref()
            .and_then(|value| value.get("output_files"))
            .is_some();
        if already_listed {
            return;
        }

        // Prefer the raw output; fall back to what the JSON result carries.
        let output_text = match result.raw_output.take() {
            Some(raw) => raw,
            None => result
                .result
                .as_ref()
                .map(extract_output_text)
                .unwrap_or_default(),
        };
        if output_text.is_empty() {
            return;
        }

        let Some(file_store) = context.file_store.as_ref() else {
            return;
        };

        let (stdout, stderr) = split_output_streams(&output_text);
        let persisted = persist_large_output(
            file_store,
            context.session_id,
            &tool_call.id,
            stdout,
            stderr,
        )
        .await;
        let Some(persist_result) = persisted else {
            return;
        };

        // Only JSON-object results can be annotated.
        let Some(obj) = result
            .result
            .as_mut()
            .and_then(|value| value.as_object_mut())
        else {
            return;
        };

        let mut output_files = Vec::new();

        if let Some(path) = persist_result.stdout_path.as_ref() {
            let current_stdout = obj.get("stdout").and_then(|v| v.as_str());
            if let Some(current_stdout) = current_stdout {
                let annotated = annotate_truncated_output(current_stdout, path, stdout.len());
                obj.insert("stdout".to_string(), json!(annotated));
            }
            output_files.push(format!("/workspace{path}"));
            obj.insert("full_output".to_string(), json!(path));
            obj.insert(
                "total_lines".to_string(),
                json!(persist_result.stdout_total_lines),
            );
        }

        if let Some(path) = persist_result.stderr_path.as_ref() {
            output_files.push(format!("/workspace{path}"));
        }

        if !output_files.is_empty() {
            obj.insert("output_files".to_string(), json!(output_files));
        }
    }
}
/// Splits combined output into `(stdout, stderr)` at the last
/// `"\n--- stderr ---\n"` separator.
///
/// The separator itself belongs to neither half. Text without a separator is
/// returned whole as stdout with an empty stderr.
fn split_output_streams(text: &str) -> (&str, &str) {
    const STDERR_SEPARATOR: &str = "\n--- stderr ---\n";
    text.rsplit_once(STDERR_SEPARATOR).unwrap_or((text, ""))
}
fn extract_output_text(json: &serde_json::Value) -> String {
let mut parts = Vec::new();
if let Some(stdout) = json.get("stdout").and_then(|v| v.as_str())
&& !stdout.is_empty()
{
parts.push(stdout);
}
if let Some(stderr) = json.get("stderr").and_then(|v| v.as_str())
&& !stderr.is_empty()
{
if !parts.is_empty() {
parts.push("\n--- stderr ---\n");
}
parts.push(stderr);
}
if parts.is_empty()
&& let Some(output) = json.get("output").and_then(|v| v.as_str())
&& !output.is_empty()
{
parts.push(output);
}
parts.join("")
}
// Unit tests for output extraction, stream splitting, annotation, truncation,
// and persistence (the latter against an in-memory mock file store).
#[cfg(test)]
mod tests {
use super::*;
use crate::tool_output_sanitizer::EXEC_OUTPUT_BUDGET;
// Both streams present: the combined text contains each stream and the stderr separator.
#[test]
fn test_extract_output_text_stdout_stderr() {
let json = json!({
"stdout": "hello world",
"stderr": "warning: unused variable",
"exit_code": 0
});
let text = extract_output_text(&json);
assert!(text.contains("hello world"));
assert!(text.contains("warning: unused variable"));
assert!(text.contains("--- stderr ---"));
}
// An empty stderr is skipped, leaving stdout alone with no separator.
#[test]
fn test_extract_output_text_stdout_only() {
let json = json!({
"stdout": "hello",
"stderr": "",
"exit_code": 0
});
assert_eq!(extract_output_text(&json), "hello");
}
// With neither stdout nor stderr, the `output` field is used.
#[test]
fn test_extract_output_text_output_field() {
let json = json!({
"output": "combined output",
"exit_code": 0
});
assert_eq!(extract_output_text(&json), "combined output");
}
// No recognized fields yields an empty string.
#[test]
fn test_extract_output_text_empty() {
let json = json!({ "exit_code": 0 });
assert_eq!(extract_output_text(&json), "");
}
// `output` is only a fallback: it is ignored when stdout is non-empty.
#[test]
fn test_extract_output_text_prefers_stdout_over_output() {
let json = json!({
"stdout": "from stdout",
"output": "from output",
"exit_code": 0
});
let text = extract_output_text(&json);
assert!(text.contains("from stdout"));
assert!(!text.contains("from output"));
}
// Checks the capability id, that at least one post-exec hook is registered,
// and the declared dependency.
#[test]
fn test_capability_metadata() {
let cap = ToolOutputPersistenceCapability;
assert_eq!(cap.id(), "tool_output_persistence");
assert!(!cap.post_tool_exec_hooks().is_empty());
assert!(cap.dependencies().contains(&"session_file_system"));
}
// The separator itself is dropped from both halves.
#[test]
fn test_split_output_streams_both() {
let text = "hello world\n--- stderr ---\nwarning: unused";
let (stdout, stderr) = split_output_streams(text);
assert_eq!(stdout, "hello world");
assert_eq!(stderr, "warning: unused");
}
// Without a separator, everything is treated as stdout.
#[test]
fn test_split_output_streams_stdout_only() {
let text = "hello world\nline two";
let (stdout, stderr) = split_output_streams(text);
assert_eq!(stdout, "hello world\nline two");
assert!(stderr.is_empty());
}
// Empty input splits into two empty halves.
#[test]
fn test_split_output_streams_empty() {
let (stdout, stderr) = split_output_streams("");
assert!(stdout.is_empty());
assert!(stderr.is_empty());
}
// The note keeps the original text and adds the /workspace-prefixed path,
// the size in KiB, and the read_file hint.
#[test]
fn test_annotate_truncated_output() {
let annotated =
annotate_truncated_output("truncated text...", "/outputs/abc.stdout", 50 * 1024);
assert!(annotated.contains("truncated text..."));
assert!(annotated.contains("/workspace/outputs/abc.stdout"));
assert!(annotated.contains("50 KiB"));
assert!(annotated.contains("read_file"));
}
// 1024 bytes is reported as 1 KiB.
#[test]
fn test_annotate_truncated_output_small() {
let annotated = annotate_truncated_output("short", "/outputs/x.stdout", 1024);
assert!(annotated.contains("1 KiB"));
assert!(annotated.contains("read_file with offset/limit"));
}
// Text within the limit is returned unchanged.
#[test]
fn test_truncate_for_persistence_noop_when_under_limit() {
let text = "hello";
assert_eq!(truncate_for_persistence(text, 1024), text);
}
// Over the limit: the first 32 bytes are kept and the marker is appended,
// so the result is longer than the limit itself.
#[test]
fn test_truncate_for_persistence_adds_marker_when_over_limit() {
let text = "x".repeat(128);
let out = truncate_for_persistence(&text, 32);
assert!(out.len() > 32);
assert!(out.starts_with(&"x".repeat(32)));
assert!(out.contains("output truncated before persistence"));
}
// Imports used by the mock file store and the async persistence tests below.
use crate::error::Result;
use crate::session_file::{FileInfo, FileStat, GrepMatch, SessionFile};
use crate::traits::SessionFileSystem;
use chrono::Utc;
use std::collections::HashMap;
use std::sync::Mutex;
use uuid::Uuid;
// In-memory stand-in for SessionFileSystem: records written content by path;
// every other operation is a stub.
#[derive(Default)]
struct MockFileStore {
files: Mutex<HashMap<String, String>>,
}
impl MockFileStore {
// Returns a copy of what was last written to `path`, if anything.
fn content(&self, path: &str) -> Option<String> {
self.files.lock().unwrap().get(path).cloned()
}
}
#[async_trait]
impl SessionFileSystem for MockFileStore {
// Stub: never finds a file.
async fn read_file(
&self,
_session_id: SessionId,
_path: &str,
) -> Result<Option<SessionFile>> {
Ok(None)
}
// Stores the content under `path` (the session id is ignored) and returns a
// SessionFile describing it, with a nil session id and a fresh random file id.
async fn write_file(
&self,
_session_id: SessionId,
path: &str,
content: &str,
_encoding: &str,
) -> Result<SessionFile> {
self.files
.lock()
.unwrap()
.insert(path.to_string(), content.to_string());
Ok(SessionFile {
id: Uuid::new_v4(),
session_id: Uuid::nil(),
path: path.to_string(),
name: path.rsplit('/').next().unwrap_or("").to_string(),
content: Some(content.to_string()),
encoding: "utf-8".to_string(),
is_directory: false,
is_readonly: false,
size_bytes: content.len() as i64,
created_at: Utc::now(),
updated_at: Utc::now(),
})
}
// Stub: reports that nothing was deleted.
async fn delete_file(
&self,
_session_id: SessionId,
_path: &str,
_recursive: bool,
) -> Result<bool> {
Ok(false)
}
// Stub: every directory is empty.
async fn list_directory(
&self,
_session_id: SessionId,
_path: &str,
) -> Result<Vec<FileInfo>> {
Ok(vec![])
}
// Stub: no file has metadata.
async fn stat_file(&self, _session_id: SessionId, _path: &str) -> Result<Option<FileStat>> {
Ok(None)
}
// Stub: no matches.
async fn grep_files(
&self,
_session_id: SessionId,
_pattern: &str,
_path_pattern: Option<&str>,
) -> Result<Vec<GrepMatch>> {
Ok(vec![])
}
// Not exercised by these tests; always returns an error.
async fn create_directory(&self, _session_id: SessionId, _path: &str) -> Result<FileInfo> {
Err(anyhow::anyhow!("not implemented").into())
}
}
// Fixed nil-UUID session id shared by the async tests.
fn test_session_id() -> SessionId {
SessionId::from(Uuid::nil())
}
// Stdout of exactly EXEC_OUTPUT_BUDGET bytes is still persisted: there is no
// minimum size for persistence.
#[tokio::test]
async fn test_persist_large_output_persists_stdout_below_budget() {
let store: Arc<dyn SessionFileSystem> = Arc::new(MockFileStore::default());
let small = "a".repeat(EXEC_OUTPUT_BUDGET);
let result = persist_large_output(&store, test_session_id(), "call-1", &small, "").await;
let r = result.expect("should persist non-empty stdout");
assert_eq!(r.stdout_path.as_deref(), Some("/outputs/call-1.stdout"));
}
// Stdout one byte over the budget is written in full; only stdout gets a path.
// NOTE(review): the length assertion assumes EXEC_OUTPUT_BUDGET + 1 does not
// exceed MAX_PERSISTED_STREAM_BYTES — confirm.
#[tokio::test]
async fn test_persist_large_output_persists_stdout_above_budget() {
let mock = Arc::new(MockFileStore::default());
let store: Arc<dyn SessionFileSystem> = mock.clone();
let large = "x".repeat(EXEC_OUTPUT_BUDGET + 1);
let result = persist_large_output(&store, test_session_id(), "call-2", &large, "").await;
let r = result.expect("should persist");
assert!(r.stdout_path.is_some());
assert_eq!(r.stdout_path.as_deref(), Some("/outputs/call-2.stdout"));
assert!(r.stderr_path.is_none());
assert_eq!(r.stdout_total_lines, 1);
assert_eq!(
mock.content("/outputs/call-2.stdout").unwrap().len(),
large.len()
);
}
// Stdout larger than MAX_PERSISTED_STREAM_BYTES is stored shortened, with the
// truncation marker.
#[tokio::test]
async fn test_persist_large_output_caps_persisted_stdout_size() {
let mock = Arc::new(MockFileStore::default());
let store: Arc<dyn SessionFileSystem> = mock.clone();
let huge = "x".repeat(MAX_PERSISTED_STREAM_BYTES + 2048);
let result = persist_large_output(&store, test_session_id(), "call-cap", &huge, "").await;
assert!(result.is_some());
let persisted = mock.content("/outputs/call-cap.stdout").unwrap();
assert!(persisted.len() < huge.len());
assert!(persisted.contains("output truncated before persistence"));
}
// Both streams are written to separate files; the stderr file holds all 4097 bytes.
#[tokio::test]
async fn test_persist_large_output_persists_stderr_above_threshold() {
let mock = Arc::new(MockFileStore::default());
let store: Arc<dyn SessionFileSystem> = mock.clone();
let large_stderr = "e".repeat(4097);
let result =
persist_large_output(&store, test_session_id(), "call-3", "small", &large_stderr).await;
let r = result.expect("should persist stderr");
assert_eq!(r.stdout_path.as_deref(), Some("/outputs/call-3.stdout"));
assert_eq!(r.stderr_path.as_deref(), Some("/outputs/call-3.stderr"));
assert_eq!(mock.content("/outputs/call-3.stderr").unwrap().len(), 4097);
}
// Non-empty stdout and stderr both produce a path.
#[tokio::test]
async fn test_persist_large_output_both_stdout_and_stderr() {
let mock = Arc::new(MockFileStore::default());
let store: Arc<dyn SessionFileSystem> = mock.clone();
let large_stdout = "o".repeat(EXEC_OUTPUT_BUDGET + 100);
let large_stderr = "e".repeat(5000);
let result = persist_large_output(
&store,
test_session_id(),
"call-4",
&large_stdout,
&large_stderr,
)
.await;
let r = result.expect("should persist both");
assert!(r.stdout_path.is_some());
assert!(r.stderr_path.is_some());
}
// Slashes and dots are stripped from the tool call id, so a traversal-style id
// collapses to a plain file name under /outputs.
#[tokio::test]
async fn test_persist_large_output_sanitizes_id() {
let mock = Arc::new(MockFileStore::default());
let store: Arc<dyn SessionFileSystem> = mock.clone();
let large = "x".repeat(EXEC_OUTPUT_BUDGET + 1);
let result =
persist_large_output(&store, test_session_id(), "call/../../../etc", &large, "").await;
let r = result.expect("should persist with sanitized id");
assert_eq!(r.stdout_path.as_deref(), Some("/outputs/calletc.stdout"));
}
// An id with no allowed characters sanitizes to empty, so nothing is persisted.
#[tokio::test]
async fn test_persist_large_output_empty_id_returns_none() {
let store: Arc<dyn SessionFileSystem> = Arc::new(MockFileStore::default());
let large = "x".repeat(EXEC_OUTPUT_BUDGET + 1);
let result = persist_large_output(&store, test_session_id(), "///...", &large, "").await;
assert!(result.is_none());
}
// 100 lines of 200 characters joined by newlines are reported as 100 total lines.
#[tokio::test]
async fn test_persist_large_output_line_count() {
let mock = Arc::new(MockFileStore::default());
let store: Arc<dyn SessionFileSystem> = mock.clone();
let line = "x".repeat(200);
let lines: Vec<&str> = std::iter::repeat_n(line.as_str(), 100).collect();
let large = lines.join("\n");
assert!(large.len() > EXEC_OUTPUT_BUDGET);
let result =
persist_large_output(&store, test_session_id(), "call-lines", &large, "").await;
let r = result.expect("should persist");
assert_eq!(r.stdout_total_lines, 100);
}
}