adk_studio/server/
sse.rs

1use crate::server::state::AppState;
2use axum::{
3    extract::{Path, Query, State},
4    response::sse::{Event, Sse},
5};
6use futures::Stream;
7use serde::Deserialize;
8use std::collections::HashMap;
9use std::convert::Infallible;
10use std::sync::Arc;
11use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader, BufWriter};
12use tokio::process::{Child, Command};
13use tokio::sync::Mutex;
14
15lazy_static::lazy_static! {
16    static ref SESSIONS: Arc<Mutex<HashMap<String, SessionProcess>>> = Arc::new(Mutex::new(HashMap::new()));
17}
18
19struct SessionProcess {
20    stdin: BufWriter<tokio::process::ChildStdin>,
21    stdout_rx: tokio::sync::mpsc::Receiver<String>,
22    stderr_rx: tokio::sync::mpsc::Receiver<String>,
23    _child: Child,
24}
25
26#[derive(Deserialize)]
27pub struct StreamQuery {
28    input: String,
29    #[serde(default)]
30    api_key: Option<String>,
31    #[serde(default)]
32    binary_path: Option<String>,
33    #[serde(default)]
34    session_id: Option<String>,
35}
36
37async fn get_or_create_session(
38    session_id: &str,
39    binary_path: &str,
40    api_key: &str,
41) -> Result<(), String> {
42    let mut sessions = SESSIONS.lock().await;
43    if sessions.contains_key(session_id) {
44        return Ok(());
45    }
46
47    let mut child = Command::new(binary_path)
48        .arg(session_id)
49        .env("GOOGLE_API_KEY", api_key)
50        .stdin(std::process::Stdio::piped())
51        .stdout(std::process::Stdio::piped())
52        .stderr(std::process::Stdio::piped())
53        .spawn()
54        .map_err(|e| format!("Failed to start binary: {}", e))?;
55
56    let stdin = BufWriter::new(child.stdin.take().unwrap());
57    let stdout = child.stdout.take().unwrap();
58    let stderr = child.stderr.take().unwrap();
59
60    let (stdout_tx, stdout_rx) = tokio::sync::mpsc::channel(100);
61    tokio::spawn(async move {
62        let mut reader = BufReader::new(stdout).lines();
63        while let Ok(Some(line)) = reader.next_line().await {
64            if stdout_tx.send(line).await.is_err() {
65                break;
66            }
67        }
68    });
69
70    let (stderr_tx, stderr_rx) = tokio::sync::mpsc::channel(100);
71    tokio::spawn(async move {
72        let mut reader = BufReader::new(stderr).lines();
73        while let Ok(Some(line)) = reader.next_line().await {
74            if stderr_tx.send(line).await.is_err() {
75                break;
76            }
77        }
78    });
79
80    sessions.insert(
81        session_id.to_string(),
82        SessionProcess { stdin, stdout_rx, stderr_rx, _child: child },
83    );
84    Ok(())
85}
86
87pub async fn stream_handler(
88    Path(_id): Path<String>,
89    Query(query): Query<StreamQuery>,
90    State(_state): State<AppState>,
91) -> Sse<impl Stream<Item = Result<Event, Infallible>>> {
92    let api_key =
93        query.api_key.or_else(|| std::env::var("GOOGLE_API_KEY").ok()).unwrap_or_default();
94    let input = query.input;
95    let binary_path = query.binary_path;
96    let session_id = query.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
97
98    let stream = async_stream::stream! {
99        let Some(bin_path) = binary_path else {
100            yield Ok(Event::default().event("error").data("No binary available. Click 'Build' first."));
101            return;
102        };
103
104        if let Err(e) = get_or_create_session(&session_id, &bin_path, &api_key).await {
105            yield Ok(Event::default().event("error").data(e));
106            return;
107        }
108
109        yield Ok(Event::default().event("session").data(session_id.clone()));
110
111        // Send input
112        {
113            let mut sessions = SESSIONS.lock().await;
114            if let Some(session) = sessions.get_mut(&session_id) {
115                if session.stdin.write_all(format!("{}\n", input).as_bytes()).await.is_err()
116                    || session.stdin.flush().await.is_err() {
117                    yield Ok(Event::default().event("error").data("Failed to send input"));
118                    return;
119                }
120            }
121        }
122
123        let timeout = tokio::time::Duration::from_secs(60);
124        let start = tokio::time::Instant::now();
125
126        loop {
127            if start.elapsed() > timeout {
128                yield Ok(Event::default().event("error").data("Timeout"));
129                break;
130            }
131
132            let (stdout_msg, stderr_msg) = {
133                let mut sessions = SESSIONS.lock().await;
134                match sessions.get_mut(&session_id) {
135                    Some(s) => (s.stdout_rx.try_recv().ok(), s.stderr_rx.try_recv().ok()),
136                    None => {
137                        yield Ok(Event::default().event("error").data("Session lost"));
138                        break;
139                    }
140                }
141            };
142
143            let mut got_data = false;
144
145            if let Some(line) = stdout_msg {
146                got_data = true;
147                let line = line.trim_start_matches("> ");
148                if let Some(sid) = line.strip_prefix("SESSION:") {
149                    yield Ok(Event::default().event("session").data(sid));
150                } else if let Some(trace) = line.strip_prefix("TRACE:") {
151                    yield Ok(Event::default().event("trace").data(trace));
152                } else if let Some(chunk) = line.strip_prefix("CHUNK:") {
153                    // Streaming chunk - emit immediately
154                    let decoded = serde_json::from_str::<String>(chunk).unwrap_or_else(|_| chunk.to_string());
155                    yield Ok(Event::default().event("chunk").data(decoded));
156                } else if let Some(response) = line.strip_prefix("RESPONSE:") {
157                    let decoded = serde_json::from_str::<String>(response).unwrap_or_else(|_| response.to_string());
158                    yield Ok(Event::default().event("chunk").data(decoded));
159                    yield Ok(Event::default().event("end").data(""));
160                    break;
161                }
162            }
163
164            if let Some(line) = stderr_msg {
165                got_data = true;
166                if let Ok(json) = serde_json::from_str::<serde_json::Value>(&line) {
167                    let fields = json.get("fields");
168                    let msg = fields.and_then(|f| f.get("message")).and_then(|m| m.as_str()).unwrap_or("");
169
170                    if msg == "tool_call" {
171                        let name = fields.and_then(|f| f.get("tool.name")).and_then(|v| v.as_str()).unwrap_or("");
172                        let args = fields.and_then(|f| f.get("tool.args")).and_then(|v| v.as_str()).unwrap_or("{}");
173                        yield Ok(Event::default().event("tool_call").data(serde_json::json!({"name": name, "args": args}).to_string()));
174                    } else if msg == "tool_result" {
175                        let name = fields.and_then(|f| f.get("tool.name")).and_then(|v| v.as_str()).unwrap_or("");
176                        let result = fields.and_then(|f| f.get("tool.result")).and_then(|v| v.as_str()).unwrap_or("");
177                        yield Ok(Event::default().event("tool_result").data(serde_json::json!({"name": name, "result": result}).to_string()));
178                    } else if msg == "Starting agent execution" {
179                        // Emit node_start for sub-agent
180                        let agent = json.get("span").and_then(|s| s.get("agent.name")).and_then(|v| v.as_str()).unwrap_or("");
181                        yield Ok(Event::default().event("trace").data(serde_json::json!({"type": "node_start", "node": agent, "step": 0}).to_string()));
182                    } else if msg == "Agent execution complete" {
183                        // Emit node_end for sub-agent - agent name is in fields
184                        let agent = fields.and_then(|f| f.get("agent.name")).and_then(|v| v.as_str()).unwrap_or("");
185                        yield Ok(Event::default().event("trace").data(serde_json::json!({"type": "node_end", "node": agent, "step": 0, "duration_ms": 0}).to_string()));
186                    } else if msg == "Generating content" {
187                        // Model call - extract details
188                        let span = json.get("span");
189                        let model = span.and_then(|s| s.get("model.name")).and_then(|v| v.as_str()).unwrap_or("");
190                        let tools = span.and_then(|s| s.get("request.tools_count")).and_then(|v| v.as_str()).unwrap_or("0");
191                        yield Ok(Event::default().event("log").data(serde_json::json!({"message": format!("Calling {} (tools: {})", model, tools)}).to_string()));
192                    }
193                }
194            }
195
196            if !got_data {
197                tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
198            }
199        }
200    };
201
202    Sse::new(stream)
203}
204
205pub async fn kill_session(Path(session_id): Path<String>) -> &'static str {
206    SESSIONS.lock().await.remove(&session_id);
207    "ok"
208}