1use anyhow::Result;
4use base64::{engine::general_purpose::STANDARD, Engine as _};
5use reqwest::Client;
6use std::time::Duration;
7use tokio::time::sleep;
8use tokio_util::sync::CancellationToken;
9
10use crate::approval::types::{PermissionRequest, QuestionRequest};
11
12#[derive(Debug, Clone)]
14pub enum SseEvent {
15 PermissionAsked(PermissionRequest),
16 PermissionReplied {
17 session_id: String,
18 request_id: String,
19 reply: String,
20 },
21 QuestionAsked(QuestionRequest),
22 QuestionReplied {
23 session_id: String,
24 request_id: String,
25 answers: Vec<Vec<String>>,
26 },
27 QuestionRejected {
28 session_id: String,
29 request_id: String,
30 },
31 Connected,
32 Disconnected(Option<String>),
33}
34
35pub struct OpenCodeEvents {
37 base_url: String,
38 password: Option<String>,
39 sender: tokio::sync::mpsc::UnboundedSender<SseEvent>,
40}
41
42impl OpenCodeEvents {
43 pub fn new(
44 base_url: String,
45 password: Option<String>,
46 sender: tokio::sync::mpsc::UnboundedSender<SseEvent>,
47 ) -> Self {
48 OpenCodeEvents {
49 base_url,
50 password,
51 sender,
52 }
53 }
54
55 pub fn start(&self, cancel: CancellationToken) -> tokio::task::JoinHandle<()> {
59 let base_url = self.base_url.clone();
60 let password = self.password.clone();
61 let sender = self.sender.clone();
62
63 tokio::spawn(async move {
64 let mut delay_secs: u64 = 1;
65
66 loop {
67 if cancel.is_cancelled() {
68 break;
69 }
70
71 match connect_and_stream(&base_url, &password, &sender, &cancel).await {
72 Ok(()) => {
73 break;
75 }
76 Err(e) => {
77 let _ = sender.send(SseEvent::Disconnected(Some(e.to_string())));
78
79 if cancel.is_cancelled() {
80 break;
81 }
82
83 tokio::select! {
87 _ = cancel.cancelled() => break,
88 _ = sleep(Duration::from_secs(delay_secs)) => {}
89 }
90 delay_secs = next_reconnect_delay(delay_secs);
91 }
92 }
93 }
94 })
95 }
96}
97
98async fn connect_and_stream(
100 base_url: &str,
101 password: &Option<String>,
102 sender: &tokio::sync::mpsc::UnboundedSender<SseEvent>,
103 cancel: &CancellationToken,
104) -> Result<()> {
105 let url = format!("{}/event", base_url);
106 let client = Client::new();
107
108 let mut req = client
109 .get(&url)
110 .header("Accept", "text/event-stream")
111 .header("Cache-Control", "no-cache");
112
113 if let Some(pw) = password {
114 let creds = format!(":{}", pw);
115 req = req.header("Authorization", format!("Basic {}", STANDARD.encode(creds)));
116 }
117
118 let response = req.send().await?;
119
120 if !response.status().is_success() {
121 anyhow::bail!("SSE connection failed with status {}", response.status());
122 }
123
124 let _ = sender.send(SseEvent::Connected);
126
127 use futures::StreamExt;
128 let mut stream = response.bytes_stream();
129
130 let mut buffer = String::new();
131
132 loop {
133 tokio::select! {
134 _ = cancel.cancelled() => {
135 return Ok(());
136 }
137 chunk = stream.next() => {
138 match chunk {
139 None => {
140 anyhow::bail!("SSE stream ended unexpectedly");
142 }
143 Some(Err(e)) => {
144 anyhow::bail!("SSE stream error: {}", e);
145 }
146 Some(Ok(bytes)) => {
147 let text = String::from_utf8_lossy(&bytes);
148 buffer.push_str(&text);
149
150 while let Some(pos) = buffer.find("\n\n") {
152 let block = buffer[..pos].to_string();
153 buffer = buffer[pos + 2..].to_string();
154 if let Some(event) = parse_sse_block(&block) {
155 let _ = sender.send(event);
156 }
157 }
158 }
159 }
160 }
161 }
162 }
163}
164
165pub fn parse_sse_block(block: &str) -> Option<SseEvent> {
170 let data = block
172 .lines()
173 .find(|line| line.starts_with("data:"))
174 .map(|line| line.trim_start_matches("data:").trim());
175
176 let data = match data {
177 Some(d) if !d.is_empty() => d,
178 _ => return None,
179 };
180
181 let json: serde_json::Value = match serde_json::from_str(data) {
183 Ok(v) => v,
184 Err(_) => return None, };
186
187 let event_type = json.get("type").and_then(|v| v.as_str())?;
188
189 let props = json
190 .get("properties")
191 .cloned()
192 .unwrap_or(serde_json::Value::Null);
193
194 match event_type {
195 "server.connected" => Some(SseEvent::Connected),
196 "server.heartbeat" => None,
197 "permission.asked" => {
198 serde_json::from_value::<PermissionRequest>(props)
199 .ok()
200 .map(SseEvent::PermissionAsked)
201 }
202 "permission.replied" => {
203 let session_id = props
204 .get("session_id")
205 .and_then(|v| v.as_str())
206 .unwrap_or("")
207 .to_string();
208 let request_id = props
209 .get("request_id")
210 .and_then(|v| v.as_str())
211 .unwrap_or("")
212 .to_string();
213 let reply = props
214 .get("reply")
215 .and_then(|v| v.as_str())
216 .unwrap_or("")
217 .to_string();
218 Some(SseEvent::PermissionReplied {
219 session_id,
220 request_id,
221 reply,
222 })
223 }
224 "question.asked" => {
225 serde_json::from_value::<QuestionRequest>(props)
226 .ok()
227 .map(SseEvent::QuestionAsked)
228 }
229 "question.replied" => {
230 let session_id = props
231 .get("session_id")
232 .and_then(|v| v.as_str())
233 .unwrap_or("")
234 .to_string();
235 let request_id = props
236 .get("request_id")
237 .and_then(|v| v.as_str())
238 .unwrap_or("")
239 .to_string();
240 let answers = props
241 .get("answers")
242 .and_then(|v| serde_json::from_value::<Vec<Vec<String>>>(v.clone()).ok())
243 .unwrap_or_default();
244 Some(SseEvent::QuestionReplied {
245 session_id,
246 request_id,
247 answers,
248 })
249 }
250 "question.rejected" => {
251 let session_id = props
252 .get("session_id")
253 .and_then(|v| v.as_str())
254 .unwrap_or("")
255 .to_string();
256 let request_id = props
257 .get("request_id")
258 .and_then(|v| v.as_str())
259 .unwrap_or("")
260 .to_string();
261 Some(SseEvent::QuestionRejected {
262 session_id,
263 request_id,
264 })
265 }
266 _ => None,
267 }
268}
269
/// Upper bound, in seconds, for the reconnect backoff delay.
const MAX_RECONNECT_DELAY_SECS: u64 = 30;

/// Returns the next reconnect delay, in seconds, for exponential backoff.
///
/// Doubles `current` and caps the result at 30 seconds. The multiplication
/// saturates, so very large inputs cannot overflow. A `current` of 0 yields 1,
/// so the backoff can never get stuck at a zero delay.
pub fn next_reconnect_delay(current: u64) -> u64 {
    current.saturating_mul(2).clamp(1, MAX_RECONNECT_DELAY_SECS)
}
274
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses a block consisting of a single `data:` line carrying `payload`.
    fn parse_data(payload: &str) -> Option<SseEvent> {
        parse_sse_block(&format!("data: {}", payload))
    }

    #[test]
    fn test_parse_connected() {
        let parsed = parse_data(r#"{"type":"server.connected","properties":{}}"#);
        assert!(matches!(parsed, Some(SseEvent::Connected)));
    }

    #[test]
    fn test_parse_heartbeat_ignored() {
        let parsed = parse_data(r#"{"type":"server.heartbeat","properties":{}}"#);
        assert!(parsed.is_none());
    }

    #[test]
    fn test_parse_malformed_json() {
        assert!(parse_data("not-valid-json").is_none());
    }

    #[test]
    fn test_parse_empty_data() {
        // A block that has no `data:` field at all.
        assert!(parse_sse_block("event: ping\n").is_none());
    }

    #[test]
    fn test_parse_unknown_type() {
        let parsed = parse_data(r#"{"type":"unknown.event","properties":{}}"#);
        assert!(parsed.is_none());
    }

    #[test]
    fn test_parse_permission_asked() {
        let parsed = parse_data(
            r#"{"type":"permission.asked","properties":{"id":"test-id","session_id":"sess","permission":"bash","patterns":[],"metadata":{},"always":[],"tool":null}}"#,
        );
        match parsed {
            Some(SseEvent::PermissionAsked(req)) => assert!(req.id == "test-id"),
            other => panic!("expected PermissionAsked, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_question_asked() {
        let parsed = parse_data(
            r#"{"type":"question.asked","properties":{"id":"q1","session_id":"s1","questions":[{"question":"What?","header":"H","options":[],"multiple":false,"custom":true}]}}"#,
        );
        match parsed {
            Some(SseEvent::QuestionAsked(req)) => assert!(req.id == "q1"),
            other => panic!("expected QuestionAsked, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_permission_replied() {
        let parsed = parse_data(
            r#"{"type":"permission.replied","properties":{"session_id":"s1","request_id":"r1","reply":"once"}}"#,
        );
        match parsed {
            Some(SseEvent::PermissionReplied {
                session_id,
                request_id,
                reply,
            }) => {
                assert_eq!(session_id, "s1");
                assert_eq!(request_id, "r1");
                assert_eq!(reply, "once");
            }
            other => panic!("expected PermissionReplied, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_question_replied() {
        let parsed = parse_data(
            r#"{"type":"question.replied","properties":{"session_id":"s1","request_id":"r1","answers":[["yes","no"]]}}"#,
        );
        match parsed {
            Some(SseEvent::QuestionReplied {
                session_id,
                request_id,
                answers,
            }) => {
                assert_eq!(session_id, "s1");
                assert_eq!(request_id, "r1");
                assert_eq!(answers, vec![vec!["yes".to_string(), "no".to_string()]]);
            }
            other => panic!("expected QuestionReplied, got {:?}", other),
        }
    }

    #[test]
    fn test_parse_question_rejected() {
        let parsed = parse_data(
            r#"{"type":"question.rejected","properties":{"session_id":"s1","request_id":"r1"}}"#,
        );
        match parsed {
            Some(SseEvent::QuestionRejected {
                session_id,
                request_id,
            }) => {
                assert_eq!(session_id, "s1");
                assert_eq!(request_id, "r1");
            }
            other => panic!("expected QuestionRejected, got {:?}", other),
        }
    }

    #[test]
    fn test_backoff_calculation() {
        let sequence: Vec<u64> =
            std::iter::successors(Some(1u64), |&delay| Some(next_reconnect_delay(delay)))
                .take(8)
                .collect();
        assert_eq!(sequence, vec![1, 2, 4, 8, 16, 30, 30, 30]);
    }

    #[test]
    fn test_parse_no_type_field() {
        assert!(parse_data(r#"{"properties":{}}"#).is_none());
    }

    #[test]
    fn test_parse_missing_properties() {
        let parsed = parse_data(r#"{"type":"server.connected"}"#);
        assert!(matches!(parsed, Some(SseEvent::Connected)));
    }
}