1use std::future::Future;
2use std::pin::Pin;
3use std::sync::{Arc, OnceLock};
4use std::time::Duration;
5
6use chimera_core::*;
7
8use crate::config::{OpenCodeConfig, OpenCodeProvider};
9use crate::translate;
10use crate::wire::*;
11
/// HTTP client timeout, in seconds, used when the backend config sets no `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 300;
/// Number of retries used when the backend config sets no `max_retries`.
const DEFAULT_MAX_RETRIES: u32 = 2;
/// Model name used when the session config does not name a model.
const DEFAULT_MODEL: &str = "glm-5";
/// Base delay for retry backoff; it is doubled per retry attempt.
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
/// Upper bound on any single retry delay, including server-provided `retry-after` values.
const MAX_RETRY_DELAY: Duration = Duration::from_secs(120);
17
/// A chat session that talks to an OpenCode provider's `/chat/completions` endpoint.
pub struct OpenCodeSession {
    /// Bearer token attached to each request when present.
    api_key: Option<String>,
    /// Provider whose `base_url()` requests are sent to.
    provider: OpenCodeProvider,
    /// Session id; written at most once (builder or a completed turn's result).
    session_id: Arc<OnceLock<String>>,
    /// Interrupt sender for the turn currently in flight; `None` when no turn is active.
    active_interrupt: Arc<tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
    /// Session configuration (model, system prompt, backend options).
    config: SessionConfig<OpenCodeConfig>,
    /// Conversation so far; replayed in every request and shared with the turn task.
    history: Arc<tokio::sync::Mutex<Vec<WireMessage>>>,
    /// HTTP client, built on first use and reused afterwards.
    client: Option<reqwest::Client>,
}
27
28impl OpenCodeSession {
29 pub(crate) fn new(
30 api_key: Option<String>,
31 provider: OpenCodeProvider,
32 config: SessionConfig<OpenCodeConfig>,
33 ) -> Self {
34 Self {
35 api_key,
36 provider,
37 session_id: Arc::new(OnceLock::new()),
38 active_interrupt: Arc::new(tokio::sync::Mutex::new(None)),
39 config,
40 history: Arc::new(tokio::sync::Mutex::new(Vec::new())),
41 client: None,
42 }
43 }
44
45 pub(crate) fn with_session_id(self, id: String) -> Self {
46 let _ = self.session_id.set(id);
47 self
48 }
49
50 #[cfg(test)]
51 pub(crate) fn api_key(&self) -> Option<&str> {
52 self.api_key.as_deref()
53 }
54
55 fn model(&self) -> &str {
56 self.config.model.as_deref().unwrap_or(DEFAULT_MODEL)
57 }
58
59 fn timeout(&self) -> Duration {
60 Duration::from_secs(
61 self.config
62 .backend
63 .timeout_secs
64 .unwrap_or(DEFAULT_TIMEOUT_SECS),
65 )
66 }
67
68 fn max_retries(&self) -> u32 {
69 self.config
70 .backend
71 .max_retries
72 .unwrap_or(DEFAULT_MAX_RETRIES)
73 }
74
75 fn base_url(&self) -> &str {
76 self.provider.base_url()
77 }
78
79 fn get_or_build_client(&mut self) -> Result<reqwest::Client> {
80 if let Some(ref client) = self.client {
81 return Ok(client.clone());
82 }
83 let client = reqwest::Client::builder()
84 .timeout(self.timeout())
85 .build()
86 .map_err(|e| AgentError::Other {
87 message: "failed to build HTTP client".into(),
88 source: Some(Box::new(e)),
89 })?;
90 self.client = Some(client.clone());
91 Ok(client)
92 }
93
94 fn build_request(
95 &self,
96 input: &Input,
97 options: &TurnOptions,
98 history: &[WireMessage],
99 ) -> ChatCompletionRequest {
100 let mut messages = Vec::new();
101
102 if let Some(sp) = &self.config.system_prompt {
103 messages.push(WireMessage::system(sp.clone()));
104 }
105
106 messages.extend_from_slice(history);
107
108 let prompt_text = prompt_text_from_input(input);
109 messages.push(WireMessage::user(prompt_text));
110
111 let response_format = options.output_schema.as_ref().map(|schema| ResponseFormat {
112 format_type: "json_schema".into(),
113 json_schema: Some(schema.clone()),
114 });
115
116 let stream = if self.config.backend.stream {
117 Some(true)
118 } else {
119 None
120 };
121
122 ChatCompletionRequest {
123 model: self.model().to_string(),
124 messages,
125 max_tokens: self.config.backend.max_tokens,
126 temperature: self.config.backend.temperature,
127 stream,
128 response_format,
129 }
130 }
131}
132
133fn prompt_text_from_input(input: &Input) -> String {
134 match input {
135 Input::Text(s) => s.clone(),
136 Input::Structured(parts) => parts
137 .iter()
138 .filter_map(|p| match p {
139 InputPart::Text(t) => Some(t.as_str()),
140 _ => None,
141 })
142 .collect::<Vec<_>>()
143 .join("\n\n"),
144 _ => String::new(),
145 }
146}
147
148impl Session for OpenCodeSession {
149 fn turn_stream(
150 &mut self,
151 input: Input,
152 options: TurnOptions,
153 ) -> Pin<Box<dyn Future<Output = Result<EventStream>> + Send + '_>> {
154 Box::pin(async move {
155 if self.active_interrupt.lock().await.is_some() {
156 return Err(AgentError::Other {
157 message: "turn already in progress".into(),
158 source: None,
159 });
160 }
161
162 let client = self.get_or_build_client()?;
163 let history_snapshot = self.history.lock().await.clone();
164 let request = self.build_request(&input, &options, &history_snapshot);
165 let url = format!("{}/chat/completions", self.base_url());
166 let api_key = self.api_key.clone();
167 let max_retries = self.max_retries();
168 let is_streaming = request.stream == Some(true);
169
170 let prompt_text = prompt_text_from_input(&input);
172 self.history
173 .lock()
174 .await
175 .push(WireMessage::user(prompt_text));
176
177 let (tx, rx) =
178 tokio::sync::mpsc::channel::<std::result::Result<AgentEvent, AgentError>>(128);
179 let session_id = Arc::clone(&self.session_id);
180 let (interrupt_tx, interrupt_rx) = tokio::sync::oneshot::channel();
181 {
182 let mut slot = self.active_interrupt.lock().await;
183 *slot = Some(interrupt_tx);
184 }
185 let active_interrupt = Arc::clone(&self.active_interrupt);
186 let history = Arc::clone(&self.history);
187 let timeout_duration = options.timeout;
188
189 tokio::spawn(async move {
190 execute_turn(
191 client,
192 url,
193 api_key,
194 request,
195 max_retries,
196 is_streaming,
197 tx,
198 interrupt_rx,
199 session_id,
200 active_interrupt,
201 history,
202 timeout_duration,
203 )
204 .await;
205 });
206
207 Ok(EventStream::from_receiver(rx))
208 })
209 }
210
211 fn session_id(&self) -> Option<&str> {
212 self.session_id.get().map(|s| s.as_str())
213 }
214
215 fn interrupt(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
216 let active_interrupt = Arc::clone(&self.active_interrupt);
217 Box::pin(async move {
218 if let Some(tx) = active_interrupt.lock().await.take() {
219 let _ = tx.send(());
220 }
221 Ok(())
222 })
223 }
224}
225
226#[allow(clippy::too_many_arguments)]
227async fn execute_turn(
228 client: reqwest::Client,
229 url: String,
230 api_key: Option<String>,
231 request: ChatCompletionRequest,
232 max_retries: u32,
233 is_streaming: bool,
234 tx: tokio::sync::mpsc::Sender<std::result::Result<AgentEvent, AgentError>>,
235 mut interrupt_rx: tokio::sync::oneshot::Receiver<()>,
236 session_id: Arc<OnceLock<String>>,
237 active_interrupt: Arc<tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
238 history: Arc<tokio::sync::Mutex<Vec<WireMessage>>>,
239 timeout_duration: Option<Duration>,
240) {
241 let timeout_sleep = tokio::time::sleep(timeout_duration.unwrap_or(Duration::MAX));
242 tokio::pin!(timeout_sleep);
243 let has_timeout = timeout_duration.is_some();
244
245 let _ = tx.send(Ok(AgentEvent::TurnStarted)).await;
246
247 let mut attempt = 0u32;
249 let response = loop {
250 let mut req = client.post(&url).json(&request);
251 if let Some(ref key) = api_key {
252 req = req.bearer_auth(key);
253 }
254
255 tokio::select! {
256 _ = &mut interrupt_rx => {
257 let _ = tx.send(Err(AgentError::Interrupted)).await;
258 active_interrupt.lock().await.take();
259 return;
260 }
261 _ = &mut timeout_sleep, if has_timeout => {
262 let _ = tx.send(Err(AgentError::Timeout {
263 duration: timeout_duration.unwrap(),
264 })).await;
265 active_interrupt.lock().await.take();
266 return;
267 }
268 result = req.send() => {
269 match result {
270 Ok(resp) if resp.status().is_success() => break resp,
271 Ok(resp) => {
272 let status = resp.status();
273 let retry_after = resp.headers()
274 .get("retry-after")
275 .and_then(|v| v.to_str().ok())
276 .and_then(|s| s.parse::<u64>().ok())
277 .unwrap_or(60);
278 let body = resp.text().await.unwrap_or_default();
279
280 if attempt >= max_retries {
281 let _ = tx.send(Ok(AgentEvent::TurnFailed {
282 message: format!("HTTP {status}: {body}"),
283 })).await;
284 active_interrupt.lock().await.take();
285 return;
286 }
287
288 let delay = if status.as_u16() == 429 {
289 Duration::from_secs(retry_after.min(MAX_RETRY_DELAY.as_secs()))
290 } else if status.is_server_error() {
291 (BASE_RETRY_DELAY * (1 << attempt)).min(MAX_RETRY_DELAY)
292 } else {
293 let _ = tx.send(Ok(AgentEvent::TurnFailed {
295 message: format!("HTTP {status}: {body}"),
296 })).await;
297 active_interrupt.lock().await.take();
298 return;
299 };
300
301 attempt += 1;
302 tokio::time::sleep(delay).await;
303 continue;
304 }
305 Err(e) => {
306 if attempt >= max_retries {
307 let _ = tx.send(Err(AgentError::Other {
308 message: format!("HTTP request failed: {e}"),
309 source: Some(Box::new(e)),
310 })).await;
311 active_interrupt.lock().await.take();
312 return;
313 }
314 attempt += 1;
315 let delay = (BASE_RETRY_DELAY * (1 << attempt)).min(MAX_RETRY_DELAY);
316 tokio::time::sleep(delay).await;
317 continue;
318 }
319 }
320 }
321 }
322 };
323
324 if is_streaming {
325 handle_sse_stream(
326 response,
327 &tx,
328 interrupt_rx,
329 &session_id,
330 &history,
331 timeout_duration,
332 )
333 .await;
334 } else {
335 handle_non_streaming(response, &tx, &session_id, &history).await;
336 }
337
338 active_interrupt.lock().await.take();
339}
340
341async fn handle_non_streaming(
342 response: reqwest::Response,
343 tx: &tokio::sync::mpsc::Sender<std::result::Result<AgentEvent, AgentError>>,
344 session_id: &Arc<OnceLock<String>>,
345 history: &Arc<tokio::sync::Mutex<Vec<WireMessage>>>,
346) {
347 match response.json::<ChatCompletionResponse>().await {
348 Ok(resp) => {
349 if let Some(choice) = resp.choices.first() {
351 history
352 .lock()
353 .await
354 .push(WireMessage::assistant(&choice.message.content));
355 }
356
357 let events = translate::translate_response(resp);
358 for event in events {
359 if let AgentEvent::TurnCompleted {
360 result: Some(ref r),
361 ..
362 } = event
363 && let Some(ref sid) = r.session_id
364 {
365 let _ = session_id.set(sid.clone());
366 }
367 if tx.send(Ok(event)).await.is_err() {
368 return;
369 }
370 }
371 }
372 Err(e) => {
373 let _ = tx
374 .send(Err(AgentError::Other {
375 message: format!("failed to parse response: {e}"),
376 source: Some(Box::new(e)),
377 }))
378 .await;
379 }
380 }
381}
382
383#[allow(clippy::too_many_arguments)]
384async fn handle_sse_stream(
385 response: reqwest::Response,
386 tx: &tokio::sync::mpsc::Sender<std::result::Result<AgentEvent, AgentError>>,
387 mut interrupt_rx: tokio::sync::oneshot::Receiver<()>,
388 session_id: &Arc<OnceLock<String>>,
389 history: &Arc<tokio::sync::Mutex<Vec<WireMessage>>>,
390 timeout_duration: Option<Duration>,
391) {
392 let timeout_sleep = tokio::time::sleep(timeout_duration.unwrap_or(Duration::MAX));
393 tokio::pin!(timeout_sleep);
394 let has_timeout = timeout_duration.is_some();
395 let mut buf = String::new();
396 let mut full_content = String::new();
397 let mut response = response;
398
399 loop {
400 tokio::select! {
401 _ = &mut interrupt_rx => {
402 let _ = tx.send(Err(AgentError::Interrupted)).await;
403 return;
404 }
405 _ = &mut timeout_sleep, if has_timeout => {
406 let _ = tx.send(Err(AgentError::Timeout {
407 duration: timeout_duration.unwrap(),
408 })).await;
409 return;
410 }
411 chunk_result = response.chunk() => {
412 match chunk_result {
413 Ok(Some(bytes)) => {
414 buf.push_str(&String::from_utf8_lossy(&bytes));
415
416 while let Some(pos) = buf.find('\n') {
417 let line = buf[..pos].to_string();
418 buf = buf[pos + 1..].to_string();
419
420 let line = line.trim();
421 if line.is_empty() {
422 continue;
423 }
424
425 if let Some(data) = line.strip_prefix("data: ") {
426 if data == "[DONE]" {
427 if !full_content.is_empty() {
429 history.lock().await.push(
430 WireMessage::assistant(&full_content),
431 );
432 }
433 return;
434 }
435
436 if let Ok(chunk) = serde_json::from_str::<ChatCompletionChunk>(data) {
437 for choice in &chunk.choices {
439 if let Some(ref c) = choice.delta.content {
440 full_content.push_str(c);
441 }
442 }
443
444 let events = translate::translate_chunk(chunk);
445 for event in events {
446 if let AgentEvent::TurnCompleted {
447 result: Some(ref r), ..
448 } = event
449 && let Some(ref sid) = r.session_id {
450 let _ = session_id.set(sid.clone());
451 }
452 if tx.send(Ok(event)).await.is_err() {
453 return;
454 }
455 }
456 }
457 }
458 }
459 }
460 Ok(None) => {
461 if !full_content.is_empty() {
463 history.lock().await.push(
464 WireMessage::assistant(&full_content),
465 );
466 }
467 return;
468 }
469 Err(e) => {
470 let _ = tx.send(Err(AgentError::Other {
471 message: format!("SSE stream error: {e}"),
472 source: Some(Box::new(e)),
473 })).await;
474 return;
475 }
476 }
477 }
478 }
479 }
480}
481
// Unit tests for session construction, request building and turn gating.
// None of them performs network I/O.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a session with a fixed API key against the `Zen` provider.
    fn make_session(config: SessionConfig<OpenCodeConfig>) -> OpenCodeSession {
        OpenCodeSession::new(Some("test-key".into()), OpenCodeProvider::Zen, config)
    }

    /// Session config with a default backend and nothing else set.
    fn default_config() -> SessionConfig<OpenCodeConfig> {
        SessionConfig::builder()
            .backend(OpenCodeConfig::default())
            .build()
    }

    // A fresh session has no id until one is set.
    #[test]
    fn session_id_none_initially() {
        let session = make_session(default_config());
        assert!(session.session_id().is_none());
    }

    // `with_session_id` makes the id visible through `session_id()`.
    #[test]
    fn session_id_after_set() {
        let session = make_session(default_config()).with_session_id("sess-1".into());
        assert_eq!(session.session_id(), Some("sess-1"));
    }

    // With no model configured, the default model name is used.
    #[test]
    fn model_defaults_to_glm5() {
        let session = make_session(default_config());
        assert_eq!(session.model(), "glm-5");
    }

    // A configured model overrides the default.
    #[test]
    fn model_from_config() {
        let config = SessionConfig::builder()
            .model("claude-sonnet-4-20250514")
            .backend(OpenCodeConfig::default())
            .build();
        let session = make_session(config);
        assert_eq!(session.model(), "claude-sonnet-4-20250514");
    }

    // Default config: a single user message, streaming requested, no response format.
    #[tokio::test]
    async fn build_request_basic() {
        let session = make_session(default_config());
        let history = session.history.lock().await.clone();
        let req = session.build_request(
            &Input::Text("hello".into()),
            &TurnOptions::default(),
            &history,
        );

        assert_eq!(req.model, "glm-5");
        assert_eq!(req.messages.len(), 1);
        assert_eq!(req.messages[0].role, "user");
        assert_eq!(req.messages[0].content, "hello");
        assert_eq!(req.stream, Some(true));
        assert!(req.response_format.is_none());
    }

    // The system prompt comes first, streaming can be turned off, and an output
    // schema produces a `json_schema` response format.
    #[tokio::test]
    async fn build_request_with_system_prompt_and_schema() {
        let config = SessionConfig::builder()
            .system_prompt("Be helpful")
            .backend(OpenCodeConfig::builder().stream(false).build())
            .build();
        let session = make_session(config);
        let options = TurnOptions {
            output_schema: Some(serde_json::json!({"type": "object"})),
            ..Default::default()
        };
        let history = session.history.lock().await.clone();
        let req = session.build_request(&Input::Text("test".into()), &options, &history);

        assert_eq!(req.messages[0].role, "system");
        assert_eq!(req.messages[0].content, "Be helpful");
        assert_eq!(req.messages[1].role, "user");
        assert!(req.stream.is_none());
        assert!(req.response_format.is_some());
        assert_eq!(
            req.response_format.as_ref().unwrap().format_type,
            "json_schema"
        );
    }

    // Earlier messages are replayed, in order, ahead of the new prompt.
    #[tokio::test]
    async fn build_request_includes_history() {
        let config = default_config();
        let session = make_session(config);
        {
            let mut h = session.history.lock().await;
            h.push(WireMessage::user("first"));
            h.push(WireMessage::assistant("reply"));
        }
        let history = session.history.lock().await.clone();
        let req = session.build_request(
            &Input::Text("second".into()),
            &TurnOptions::default(),
            &history,
        );

        assert_eq!(req.messages.len(), 3);
        assert_eq!(req.messages[0].content, "first");
        assert_eq!(req.messages[1].content, "reply");
        assert_eq!(req.messages[2].content, "second");
    }

    // A populated interrupt slot simulates an in-flight turn; a second turn must be
    // rejected before any request is attempted.
    #[tokio::test]
    async fn turn_stream_rejects_when_turn_active() {
        let mut session = make_session(default_config());
        let (tx, _rx) = tokio::sync::oneshot::channel();
        *session.active_interrupt.lock().await = Some(tx);

        let err = session
            .turn_stream(Input::Text("hi".into()), TurnOptions::default())
            .await;
        match err {
            Err(AgentError::Other { message, .. }) => {
                assert_eq!(message, "turn already in progress")
            }
            _ => panic!("expected Other error"),
        }
    }
}