1use crate::server::state::AppState;
2use axum::{
3 extract::{Path, Query, State},
4 response::sse::{Event, Sse},
5};
6use futures::Stream;
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::convert::Infallible;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
12use tokio::process::{Child, Command};
13use tokio::sync::Mutex;
14
15lazy_static::lazy_static! {
16 static ref SESSIONS: Arc<Mutex<HashMap<String, SessionProcess>>> = Arc::new(Mutex::new(HashMap::new()));
17}
18
/// A spawned agent process together with the handles used to talk to it.
///
/// NOTE(review): fields drop in declaration order, so the `stdin` pipe is closed before the
/// `_child` handle is dropped — keep this order unless that is verified not to matter.
struct SessionProcess {
    /// Buffered writer for the process's stdin; user input is written here followed by `\n`.
    stdin: BufWriter<tokio::process::ChildStdin>,
    /// Lines read from the process's stdout by a background forwarding task.
    stdout_rx: tokio::sync::mpsc::Receiver<String>,
    /// Lines read from the process's stderr by a background forwarding task.
    stderr_rx: tokio::sync::mpsc::Receiver<String>,
    /// Handle to the spawned process, held for as long as the session exists.
    _child: Child,
}
25
26#[derive(Deserialize)]
27pub struct StreamQuery {
28 input: String,
29 #[serde(default)]
30 api_key: Option<String>,
31 #[serde(default)]
32 binary_path: Option<String>,
33 #[serde(default)]
34 session_id: Option<String>,
35}
36
37async fn get_or_create_session(
38 session_id: &str,
39 binary_path: &str,
40 api_key: &str,
41) -> Result<(), String> {
42 let mut sessions = SESSIONS.lock().await;
43 if sessions.contains_key(session_id) {
44 return Ok(());
45 }
46
47 let mut child = Command::new(binary_path)
48 .arg(session_id)
49 .env("GOOGLE_API_KEY", api_key)
50 .stdin(std::process::Stdio::piped())
51 .stdout(std::process::Stdio::piped())
52 .stderr(std::process::Stdio::piped())
53 .spawn()
54 .map_err(|e| format!("Failed to start binary: {}", e))?;
55
56 let stdin = BufWriter::new(child.stdin.take().unwrap());
57 let stdout = child.stdout.take().unwrap();
58 let stderr = child.stderr.take().unwrap();
59
60 let (stdout_tx, stdout_rx) = tokio::sync::mpsc::channel(100);
61 tokio::spawn(async move {
62 let mut reader = BufReader::new(stdout).lines();
63 while let Ok(Some(line)) = reader.next_line().await {
64 if stdout_tx.send(line).await.is_err() {
65 break;
66 }
67 }
68 });
69
70 let (stderr_tx, stderr_rx) = tokio::sync::mpsc::channel(100);
71 tokio::spawn(async move {
72 let mut reader = BufReader::new(stderr).lines();
73 while let Ok(Some(line)) = reader.next_line().await {
74 if stderr_tx.send(line).await.is_err() {
75 break;
76 }
77 }
78 });
79
80 sessions.insert(
81 session_id.to_string(),
82 SessionProcess { stdin, stdout_rx, stderr_rx, _child: child },
83 );
84 Ok(())
85}
86
87pub async fn stream_handler(
88 Path(_id): Path<String>,
89 Query(query): Query<StreamQuery>,
90 State(_state): State<AppState>,
91) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
92 let api_key =
93 query.api_key.or_else(|| std::env::var("GOOGLE_API_KEY").ok()).unwrap_or_default();
94 let input = query.input;
95 let binary_path = query.binary_path;
96 let session_id = query.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
97
98 let stream = async_stream::stream! {
99 let Some(bin_path) = binary_path else {
100 yield Ok(Event::default().event("error").data("No binary available. Click 'Build' first."));
101 return;
102 };
103
104 if let Err(e) = get_or_create_session(&session_id, &bin_path, &api_key).await {
105 yield Ok(Event::default().event("error").data(e));
106 return;
107 }
108
109 yield Ok(Event::default().event("session").data(session_id.clone()));
110
111 {
113 let mut sessions = SESSIONS.lock().await;
114 if let Some(session) = sessions.get_mut(&session_id) {
115 if session.stdin.write_all(format!("{}\n", input).as_bytes()).await.is_err()
116 || session.stdin.flush().await.is_err() {
117 yield Ok(Event::default().event("error").data("Failed to send input"));
118 return;
119 }
120 }
121 }
122
123 let timeout = tokio::time::Duration::from_secs(60);
124 let start = tokio::time::Instant::now();
125
126 loop {
127 if start.elapsed() > timeout {
128 yield Ok(Event::default().event("error").data("Timeout"));
129 break;
130 }
131
132 let (stdout_msg, stderr_msg) = {
133 let mut sessions = SESSIONS.lock().await;
134 match sessions.get_mut(&session_id) {
135 Some(s) => (s.stdout_rx.try_recv().ok(), s.stderr_rx.try_recv().ok()),
136 None => {
137 yield Ok(Event::default().event("error").data("Session lost"));
138 break;
139 }
140 }
141 };
142
143 let mut got_data = false;
144
145 if let Some(line) = stdout_msg {
146 got_data = true;
147 let line = line.trim_start_matches("> ");
148 if let Some(sid) = line.strip_prefix("SESSION:") {
149 yield Ok(Event::default().event("session").data(sid));
150 } else if let Some(trace) = line.strip_prefix("TRACE:") {
151 yield Ok(Event::default().event("trace").data(trace));
152 } else if let Some(chunk) = line.strip_prefix("CHUNK:") {
153 let decoded = serde_json::from_str::<String>(chunk).unwrap_or_else(|_| chunk.to_string());
155 yield Ok(Event::default().event("chunk").data(decoded));
156 } else if let Some(response) = line.strip_prefix("RESPONSE:") {
157 let decoded = serde_json::from_str::<String>(response).unwrap_or_else(|_| response.to_string());
158 yield Ok(Event::default().event("chunk").data(decoded));
159 yield Ok(Event::default().event("end").data(""));
160 break;
161 }
162 }
163
164 if let Some(line) = stderr_msg {
165 got_data = true;
166 if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
167 let fields = json.get("fields");
168 let msg = fields.and_then(|f| f.get("message")).and_then(|m| m.as_str()).unwrap_or("");
169
170 if msg == "tool_call" {
171 let name = fields.and_then(|f| f.get("tool.name")).and_then(|v| v.as_str()).unwrap_or("");
172 let args = fields.and_then(|f| f.get("tool.args")).and_then(|v| v.as_str()).unwrap_or("{}");
173 yield Ok(Event::default().event("tool_call").data(serde_json::json!({"name": name, "args": args}).to_string()));
174 } else if msg == "tool_result" {
175 let name = fields.and_then(|f| f.get("tool.name")).and_then(|v| v.as_str()).unwrap_or("");
176 let result = fields.and_then(|f| f.get("tool.result")).and_then(|v| v.as_str()).unwrap_or("");
177 yield Ok(Event::default().event("tool_result").data(serde_json::json!({"name": name, "result": result}).to_string()));
178 } else if msg == "Starting agent execution" {
179 let agent = json.get("span").and_then(|s| s.get("agent.name")).and_then(|v| v.as_str()).unwrap_or("");
181 yield Ok(Event::default().event("trace").data(serde_json::json!({"type": "node_start", "node": agent, "step": 0}).to_string()));
182 } else if msg == "Agent execution complete" {
183 let agent = fields.and_then(|f| f.get("agent.name")).and_then(|v| v.as_str()).unwrap_or("");
185 yield Ok(Event::default().event("trace").data(serde_json::json!({"type": "node_end", "node": agent, "step": 0, "duration_ms": 0}).to_string()));
186 } else if msg == "Generating content" {
187 let span = json.get("span");
189 let model = span.and_then(|s| s.get("model.name")).and_then(|v| v.as_str()).unwrap_or("");
190 let tools = span.and_then(|s| s.get("request.tools_count")).and_then(|v| v.as_str()).unwrap_or("0");
191 yield Ok(Event::default().event("log").data(serde_json::json!({"message": format!("Calling {} (tools: {})", model, tools)}).to_string()));
192 }
193 }
194 }
195
196 if !got_data {
197 tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
198 }
199 }
200 };
201
202 Sse::new(stream)
203}
204
205pub async fn kill_session(Path(session_id): Path<String>) -> &'static str {
206 SESSIONS.lock().await.remove(&session_id);
207 "ok"
208}