use std::io::{BufRead, Write};
use std::sync::{Arc, Mutex};
use tau_agent_base::plugin_protocol::{PluginMessage, PluginRequest, PluginToolResult};
use tau_agent_base::protocol::{Request, Response};
use tau_agent_base::types::{TextContent, ToolResultContent};
/// A cloneable handle to one writer shared behind an `Arc<Mutex<_>>`.
///
/// Every clone points at the same underlying writer, so all handles funnel
/// their output through the same mutex. The default writer type is a
/// buffered handle to the process's standard output.
pub struct SharedStdout<W: Write + Send + 'static = std::io::BufWriter<std::io::Stdout>> {
    /// The shared writer; the mutex is locked for the duration of each write call.
    inner: Arc<Mutex<W>>,
}

// `Clone` is written by hand instead of derived: `#[derive(Clone)]` would add a
// `W: Clone` bound, which writers such as `BufWriter<Stdout>` (the default) do
// not satisfy, even though cloning a handle only copies the `Arc`.
impl<W: Write + Send + 'static> Clone for SharedStdout<W> {
    /// Returns another handle to the same underlying writer.
    fn clone(&self) -> Self {
        Self {
            inner: Arc::clone(&self.inner),
        }
    }
}
impl<W: Write + Send + 'static> SharedStdout<W> {
pub fn new(writer: W) -> Self {
Self {
inner: Arc::new(Mutex::new(writer)),
}
}
pub fn from_arc(inner: Arc<Mutex<W>>) -> Self {
Self { inner }
}
pub fn inner(&self) -> Arc<Mutex<W>> {
Arc::clone(&self.inner)
}
}
/// Forwards every `Write` call to the shared writer while holding its lock.
///
/// # Panics
///
/// Each method panics with "SharedStdout mutex poisoned" if the mutex was
/// poisoned by a panic in another lock holder.
impl<W: Write + Send + 'static> Write for SharedStdout<W> {
    /// Performs one `write` on the shared writer under the lock.
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        self.inner
            .lock()
            .expect("SharedStdout mutex poisoned")
            .write(buf)
    }

    /// Writes all of `buf` under a single lock acquisition.
    ///
    /// This is overridden on purpose: the default `write_all` calls `write`
    /// in a loop, which here would take the lock once per partial write and
    /// allow other handles to slip output in between.
    fn write_all(&mut self, buf: &[u8]) -> std::io::Result<()> {
        self.inner
            .lock()
            .expect("SharedStdout mutex poisoned")
            .write_all(buf)
    }

    /// Flushes the shared writer under the lock.
    fn flush(&mut self) -> std::io::Result<()> {
        self.inner
            .lock()
            .expect("SharedStdout mutex poisoned")
            .flush()
    }
}
/// Serializes `msg` as one JSON line and writes it to `writer`, best-effort.
///
/// The JSON text and its trailing newline go out in a single `write_all`
/// call, followed by a `flush`. Failures are deliberately ignored: if
/// serialization fails nothing is written, and write/flush errors are
/// dropped (the flush is still attempted after a failed write).
pub fn send_message(writer: &mut impl Write, msg: &PluginMessage) {
    let mut payload = match serde_json::to_string(msg) {
        Ok(json) => json,
        // Nothing sensible to emit for a message that cannot be encoded.
        Err(_) => return,
    };
    payload.push('\n');

    let _ = writer.write_all(payload.as_bytes());
    let _ = writer.flush();
}
pub fn server_request(
writer: &mut impl Write,
reader: &mut impl BufRead,
request: Request,
prefix: &str,
) -> tau_agent_base::Result<Response> {
use std::sync::atomic::{AtomicU64, Ordering};
static SEQ: AtomicU64 = AtomicU64::new(0);
let seq = SEQ.fetch_add(1, Ordering::Relaxed);
let request_id = format!(
"{}-{}-{}",
prefix,
tau_agent_base::types::timestamp_ms(),
seq
);
send_message(
writer,
&PluginMessage::ServerRequest {
request_id: request_id.clone(),
request,
},
);
let mut line = String::new();
loop {
line.clear();
match reader.read_line(&mut line) {
Ok(0) => {
return Err(tau_agent_base::Error::Io(
"stdin closed while waiting for server response".into(),
));
}
Ok(_) => {}
Err(e) => {
return Err(tau_agent_base::Error::Io(format!("read error: {}", e)));
}
}
if line.trim().is_empty() {
continue;
}
let req: PluginRequest = match serde_json::from_str(&line) {
Ok(r) => r,
Err(_) => continue,
};
match req {
PluginRequest::ServerResponse {
request_id: rid,
response,
} if rid == request_id => {
return Ok(response);
}
PluginRequest::ToolCall { tool_call_id, .. } => {
send_message(
writer,
&PluginMessage::ToolResult(PluginToolResult {
tool_call_id,
content: vec![ToolResultContent::Text(TextContent {
text: "plugin is busy with a background operation — please retry \
in a moment"
.into(),
text_signature: None,
})],
is_error: true,
summary: None,
post_persist_actions: Vec::new(),
}),
);
}
_ => {}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};
    use std::thread;

    /// Several threads write whole lines through clones of one `SharedStdout`;
    /// every line in the combined output must be intact and none may be lost.
    #[test]
    fn shared_stdout_writes_are_atomic_across_threads() {
        const THREADS: usize = 8;
        const WRITES_PER_THREAD: usize = 100;

        let sink = Arc::new(Mutex::new(Vec::<u8>::new()));
        let writer = SharedStdout::from_arc(Arc::clone(&sink));

        let mut workers = Vec::with_capacity(THREADS);
        for tid in 0..THREADS {
            let mut out = writer.clone();
            workers.push(thread::spawn(move || {
                let payload = format!("thread-{:03}-payload\n", tid);
                for _ in 0..WRITES_PER_THREAD {
                    out.write_all(payload.as_bytes()).expect("write");
                }
            }));
        }
        for worker in workers {
            worker.join().expect("join");
        }

        let captured = String::from_utf8(sink.lock().expect("sink").clone()).expect("utf8");
        let total = captured
            .lines()
            .inspect(|line| {
                assert!(
                    line.starts_with("thread-") && line.ends_with("-payload"),
                    "interleaved or truncated line: {:?}",
                    line
                );
            })
            .count();
        assert_eq!(total, THREADS * WRITES_PER_THREAD);
    }

    /// Ids built from a fixed timestamp plus an incrementing counter never repeat.
    ///
    /// NOTE(review): this uses its own local counter and format string; it
    /// does not call `server_request`, so it checks the id scheme rather than
    /// the production code path.
    #[test]
    fn server_request_ids_are_unique_under_burst() {
        use std::collections::HashSet;
        use std::sync::atomic::{AtomicU64, Ordering};

        static SEQ: AtomicU64 = AtomicU64::new(0);
        let ts = 1_700_000_000_000u64;

        let mut seen = HashSet::new();
        let burst = (0..1000).map(|_| {
            let next = SEQ.fetch_add(1, Ordering::Relaxed);
            format!("merge-sr-{}-{}", ts, next)
        });
        for id in burst {
            assert!(seen.insert(id), "duplicate id generated");
        }
    }
}