Skip to main content

chimera_opencode/
session.rs

1use std::future::Future;
2use std::pin::Pin;
3use std::sync::{Arc, OnceLock};
4use std::time::Duration;
5
6use chimera_core::*;
7
8use crate::config::{OpenCodeConfig, OpenCodeProvider};
9use crate::translate;
10use crate::wire::*;
11
/// HTTP client timeout, in seconds, used when the backend config sets no `timeout_secs`.
const DEFAULT_TIMEOUT_SECS: u64 = 300;
/// Number of retries after a failed request when the backend config sets no `max_retries`.
const DEFAULT_MAX_RETRIES: u32 = 2;
/// Model name sent in requests when the session config names no model.
const DEFAULT_MODEL: &str = "glm-5";
/// Starting point of the exponential retry backoff.
const BASE_RETRY_DELAY: Duration = Duration::from_secs(5);
/// Upper bound on any single retry delay, including server-provided `retry-after` values.
const MAX_RETRY_DELAY: Duration = Duration::from_secs(120);
17
18pub struct OpenCodeSession {
19    api_key: Option<String>,
20    provider: OpenCodeProvider,
21    session_id: Arc<OnceLock<String>>,
22    active_interrupt: Arc<tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
23    config: SessionConfig<OpenCodeConfig>,
24    history: Arc<tokio::sync::Mutex<Vec<WireMessage>>>,
25    client: Option<reqwest::Client>,
26}
27
28impl OpenCodeSession {
29    pub(crate) fn new(
30        api_key: Option<String>,
31        provider: OpenCodeProvider,
32        config: SessionConfig<OpenCodeConfig>,
33    ) -> Self {
34        Self {
35            api_key,
36            provider,
37            session_id: Arc::new(OnceLock::new()),
38            active_interrupt: Arc::new(tokio::sync::Mutex::new(None)),
39            config,
40            history: Arc::new(tokio::sync::Mutex::new(Vec::new())),
41            client: None,
42        }
43    }
44
45    pub(crate) fn with_session_id(self, id: String) -> Self {
46        let _ = self.session_id.set(id);
47        self
48    }
49
50    #[cfg(test)]
51    pub(crate) fn api_key(&self) -> Option<&str> {
52        self.api_key.as_deref()
53    }
54
55    fn model(&self) -> &str {
56        self.config.model.as_deref().unwrap_or(DEFAULT_MODEL)
57    }
58
59    fn timeout(&self) -> Duration {
60        Duration::from_secs(
61            self.config
62                .backend
63                .timeout_secs
64                .unwrap_or(DEFAULT_TIMEOUT_SECS),
65        )
66    }
67
68    fn max_retries(&self) -> u32 {
69        self.config
70            .backend
71            .max_retries
72            .unwrap_or(DEFAULT_MAX_RETRIES)
73    }
74
75    fn base_url(&self) -> &str {
76        self.provider.base_url()
77    }
78
79    fn get_or_build_client(&mut self) -> Result<reqwest::Client> {
80        if let Some(ref client) = self.client {
81            return Ok(client.clone());
82        }
83        let client = reqwest::Client::builder()
84            .timeout(self.timeout())
85            .build()
86            .map_err(|e| AgentError::Other {
87                message: "failed to build HTTP client".into(),
88                source: Some(Box::new(e)),
89            })?;
90        self.client = Some(client.clone());
91        Ok(client)
92    }
93
94    fn build_request(
95        &self,
96        input: &Input,
97        options: &TurnOptions,
98        history: &[WireMessage],
99    ) -> ChatCompletionRequest {
100        let mut messages = Vec::new();
101
102        if let Some(sp) = &self.config.system_prompt {
103            messages.push(WireMessage::system(sp.clone()));
104        }
105
106        messages.extend_from_slice(history);
107
108        let prompt_text = prompt_text_from_input(input);
109        messages.push(WireMessage::user(prompt_text));
110
111        let response_format = options.output_schema.as_ref().map(|schema| ResponseFormat {
112            format_type: "json_schema".into(),
113            json_schema: Some(schema.clone()),
114        });
115
116        let stream = if self.config.backend.stream {
117            Some(true)
118        } else {
119            None
120        };
121
122        ChatCompletionRequest {
123            model: self.model().to_string(),
124            messages,
125            max_tokens: self.config.backend.max_tokens,
126            temperature: self.config.backend.temperature,
127            stream,
128            response_format,
129        }
130    }
131}
132
133fn prompt_text_from_input(input: &Input) -> String {
134    match input {
135        Input::Text(s) => s.clone(),
136        Input::Structured(parts) => parts
137            .iter()
138            .filter_map(|p| match p {
139                InputPart::Text(t) => Some(t.as_str()),
140                _ => None,
141            })
142            .collect::<Vec<_>>()
143            .join("\n\n"),
144        _ => String::new(),
145    }
146}
147
148impl Session for OpenCodeSession {
149    fn turn_stream(
150        &mut self,
151        input: Input,
152        options: TurnOptions,
153    ) -> Pin<Box<dyn Future<Output = Result<EventStream>> + Send + '_>> {
154        Box::pin(async move {
155            if self.active_interrupt.lock().await.is_some() {
156                return Err(AgentError::Other {
157                    message: "turn already in progress".into(),
158                    source: None,
159                });
160            }
161
162            let client = self.get_or_build_client()?;
163            let history_snapshot = self.history.lock().await.clone();
164            let request = self.build_request(&input, &options, &history_snapshot);
165            let url = format!("{}/chat/completions", self.base_url());
166            let api_key = self.api_key.clone();
167            let max_retries = self.max_retries();
168            let is_streaming = request.stream == Some(true);
169
170            // Append user message to history for multi-turn.
171            let prompt_text = prompt_text_from_input(&input);
172            self.history
173                .lock()
174                .await
175                .push(WireMessage::user(prompt_text));
176
177            let (tx, rx) =
178                tokio::sync::mpsc::channel::<std::result::Result<AgentEvent, AgentError>>(128);
179            let session_id = Arc::clone(&self.session_id);
180            let (interrupt_tx, interrupt_rx) = tokio::sync::oneshot::channel();
181            {
182                let mut slot = self.active_interrupt.lock().await;
183                *slot = Some(interrupt_tx);
184            }
185            let active_interrupt = Arc::clone(&self.active_interrupt);
186            let history = Arc::clone(&self.history);
187            let timeout_duration = options.timeout;
188
189            tokio::spawn(async move {
190                execute_turn(
191                    client,
192                    url,
193                    api_key,
194                    request,
195                    max_retries,
196                    is_streaming,
197                    tx,
198                    interrupt_rx,
199                    session_id,
200                    active_interrupt,
201                    history,
202                    timeout_duration,
203                )
204                .await;
205            });
206
207            Ok(EventStream::from_receiver(rx))
208        })
209    }
210
211    fn session_id(&self) -> Option<&str> {
212        self.session_id.get().map(|s| s.as_str())
213    }
214
215    fn interrupt(&mut self) -> Pin<Box<dyn Future<Output = Result<()>> + Send + '_>> {
216        let active_interrupt = Arc::clone(&self.active_interrupt);
217        Box::pin(async move {
218            if let Some(tx) = active_interrupt.lock().await.take() {
219                let _ = tx.send(());
220            }
221            Ok(())
222        })
223    }
224}
225
226#[allow(clippy::too_many_arguments)]
227async fn execute_turn(
228    client: reqwest::Client,
229    url: String,
230    api_key: Option<String>,
231    request: ChatCompletionRequest,
232    max_retries: u32,
233    is_streaming: bool,
234    tx: tokio::sync::mpsc::Sender<std::result::Result<AgentEvent, AgentError>>,
235    mut interrupt_rx: tokio::sync::oneshot::Receiver<()>,
236    session_id: Arc<OnceLock<String>>,
237    active_interrupt: Arc<tokio::sync::Mutex<Option<tokio::sync::oneshot::Sender<()>>>>,
238    history: Arc<tokio::sync::Mutex<Vec<WireMessage>>>,
239    timeout_duration: Option<Duration>,
240) {
241    let timeout_sleep = tokio::time::sleep(timeout_duration.unwrap_or(Duration::MAX));
242    tokio::pin!(timeout_sleep);
243    let has_timeout = timeout_duration.is_some();
244
245    let _ = tx.send(Ok(AgentEvent::TurnStarted)).await;
246
247    // Retry loop for the HTTP request.
248    let mut attempt = 0u32;
249    let response = loop {
250        let mut req = client.post(&url).json(&request);
251        if let Some(ref key) = api_key {
252            req = req.bearer_auth(key);
253        }
254
255        tokio::select! {
256            _ = &mut interrupt_rx => {
257                let _ = tx.send(Err(AgentError::Interrupted)).await;
258                active_interrupt.lock().await.take();
259                return;
260            }
261            _ = &mut timeout_sleep, if has_timeout => {
262                let _ = tx.send(Err(AgentError::Timeout {
263                    duration: timeout_duration.unwrap(),
264                })).await;
265                active_interrupt.lock().await.take();
266                return;
267            }
268            result = req.send() => {
269                match result {
270                    Ok(resp) if resp.status().is_success() => break resp,
271                    Ok(resp) => {
272                        let status = resp.status();
273                        let retry_after = resp.headers()
274                            .get("retry-after")
275                            .and_then(|v| v.to_str().ok())
276                            .and_then(|s| s.parse::<u64>().ok())
277                            .unwrap_or(60);
278                        let body = resp.text().await.unwrap_or_default();
279
280                        if attempt >= max_retries {
281                            let _ = tx.send(Ok(AgentEvent::TurnFailed {
282                                message: format!("HTTP {status}: {body}"),
283                            })).await;
284                            active_interrupt.lock().await.take();
285                            return;
286                        }
287
288                        let delay = if status.as_u16() == 429 {
289                            Duration::from_secs(retry_after.min(MAX_RETRY_DELAY.as_secs()))
290                        } else if status.is_server_error() {
291                            (BASE_RETRY_DELAY * (1 << attempt)).min(MAX_RETRY_DELAY)
292                        } else {
293                            // Client error (4xx other than 429) -- don't retry.
294                            let _ = tx.send(Ok(AgentEvent::TurnFailed {
295                                message: format!("HTTP {status}: {body}"),
296                            })).await;
297                            active_interrupt.lock().await.take();
298                            return;
299                        };
300
301                        attempt += 1;
302                        tokio::time::sleep(delay).await;
303                        continue;
304                    }
305                    Err(e) => {
306                        if attempt >= max_retries {
307                            let _ = tx.send(Err(AgentError::Other {
308                                message: format!("HTTP request failed: {e}"),
309                                source: Some(Box::new(e)),
310                            })).await;
311                            active_interrupt.lock().await.take();
312                            return;
313                        }
314                        attempt += 1;
315                        let delay = (BASE_RETRY_DELAY * (1 << attempt)).min(MAX_RETRY_DELAY);
316                        tokio::time::sleep(delay).await;
317                        continue;
318                    }
319                }
320            }
321        }
322    };
323
324    if is_streaming {
325        handle_sse_stream(
326            response,
327            &tx,
328            interrupt_rx,
329            &session_id,
330            &history,
331            timeout_duration,
332        )
333        .await;
334    } else {
335        handle_non_streaming(response, &tx, &session_id, &history).await;
336    }
337
338    active_interrupt.lock().await.take();
339}
340
341async fn handle_non_streaming(
342    response: reqwest::Response,
343    tx: &tokio::sync::mpsc::Sender<std::result::Result<AgentEvent, AgentError>>,
344    session_id: &Arc<OnceLock<String>>,
345    history: &Arc<tokio::sync::Mutex<Vec<WireMessage>>>,
346) {
347    match response.json::<ChatCompletionResponse>().await {
348        Ok(resp) => {
349            // Append assistant response to history for multi-turn.
350            if let Some(choice) = resp.choices.first() {
351                history
352                    .lock()
353                    .await
354                    .push(WireMessage::assistant(&choice.message.content));
355            }
356
357            let events = translate::translate_response(resp);
358            for event in events {
359                if let AgentEvent::TurnCompleted {
360                    result: Some(ref r),
361                    ..
362                } = event
363                    && let Some(ref sid) = r.session_id
364                {
365                    let _ = session_id.set(sid.clone());
366                }
367                if tx.send(Ok(event)).await.is_err() {
368                    return;
369                }
370            }
371        }
372        Err(e) => {
373            let _ = tx
374                .send(Err(AgentError::Other {
375                    message: format!("failed to parse response: {e}"),
376                    source: Some(Box::new(e)),
377                }))
378                .await;
379        }
380    }
381}
382
383#[allow(clippy::too_many_arguments)]
384async fn handle_sse_stream(
385    response: reqwest::Response,
386    tx: &tokio::sync::mpsc::Sender<std::result::Result<AgentEvent, AgentError>>,
387    mut interrupt_rx: tokio::sync::oneshot::Receiver<()>,
388    session_id: &Arc<OnceLock<String>>,
389    history: &Arc<tokio::sync::Mutex<Vec<WireMessage>>>,
390    timeout_duration: Option<Duration>,
391) {
392    let timeout_sleep = tokio::time::sleep(timeout_duration.unwrap_or(Duration::MAX));
393    tokio::pin!(timeout_sleep);
394    let has_timeout = timeout_duration.is_some();
395    let mut buf = String::new();
396    let mut full_content = String::new();
397    let mut response = response;
398
399    loop {
400        tokio::select! {
401            _ = &mut interrupt_rx => {
402                let _ = tx.send(Err(AgentError::Interrupted)).await;
403                return;
404            }
405            _ = &mut timeout_sleep, if has_timeout => {
406                let _ = tx.send(Err(AgentError::Timeout {
407                    duration: timeout_duration.unwrap(),
408                })).await;
409                return;
410            }
411            chunk_result = response.chunk() => {
412                match chunk_result {
413                    Ok(Some(bytes)) => {
414                        buf.push_str(&String::from_utf8_lossy(&bytes));
415
416                        while let Some(pos) = buf.find('\n') {
417                            let line = buf[..pos].to_string();
418                            buf = buf[pos + 1..].to_string();
419
420                            let line = line.trim();
421                            if line.is_empty() {
422                                continue;
423                            }
424
425                            if let Some(data) = line.strip_prefix("data: ") {
426                                if data == "[DONE]" {
427                                    // Append accumulated content to history.
428                                    if !full_content.is_empty() {
429                                        history.lock().await.push(
430                                            WireMessage::assistant(&full_content),
431                                        );
432                                    }
433                                    return;
434                                }
435
436                                if let Ok(chunk) = serde_json::from_str::<ChatCompletionChunk>(data) {
437                                    // Accumulate content for history.
438                                    for choice in &chunk.choices {
439                                        if let Some(ref c) = choice.delta.content {
440                                            full_content.push_str(c);
441                                        }
442                                    }
443
444                                    let events = translate::translate_chunk(chunk);
445                                    for event in events {
446                                        if let AgentEvent::TurnCompleted {
447                                            result: Some(ref r), ..
448                                        } = event
449                                            && let Some(ref sid) = r.session_id {
450                                                let _ = session_id.set(sid.clone());
451                                            }
452                                        if tx.send(Ok(event)).await.is_err() {
453                                            return;
454                                        }
455                                    }
456                                }
457                            }
458                        }
459                    }
460                    Ok(None) => {
461                        // Stream ended without [DONE]. Append any accumulated content.
462                        if !full_content.is_empty() {
463                            history.lock().await.push(
464                                WireMessage::assistant(&full_content),
465                            );
466                        }
467                        return;
468                    }
469                    Err(e) => {
470                        let _ = tx.send(Err(AgentError::Other {
471                            message: format!("SSE stream error: {e}"),
472                            source: Some(Box::new(e)),
473                        })).await;
474                        return;
475                    }
476                }
477            }
478        }
479    }
480}
481
#[cfg(test)]
mod tests {
    use super::*;

    /// Session config with default backend settings and nothing else set.
    fn default_config() -> SessionConfig<OpenCodeConfig> {
        let backend = OpenCodeConfig::default();
        SessionConfig::builder().backend(backend).build()
    }

    /// Session against the `Zen` provider with a fixed test API key.
    fn make_session(config: SessionConfig<OpenCodeConfig>) -> OpenCodeSession {
        let api_key = Some(String::from("test-key"));
        OpenCodeSession::new(api_key, OpenCodeProvider::Zen, config)
    }

    #[test]
    fn session_id_none_initially() {
        let fresh = make_session(default_config());
        assert_eq!(fresh.session_id(), None);
    }

    #[test]
    fn session_id_after_set() {
        let seeded = make_session(default_config()).with_session_id("sess-1".into());
        assert_eq!(seeded.session_id(), Some("sess-1"));
    }

    #[test]
    fn model_defaults_to_glm5() {
        let session = make_session(default_config());
        assert_eq!(session.model(), "glm-5");
    }

    #[test]
    fn model_from_config() {
        let backend = OpenCodeConfig::default();
        let config = SessionConfig::builder()
            .model("claude-sonnet-4-20250514")
            .backend(backend)
            .build();
        assert_eq!(make_session(config).model(), "claude-sonnet-4-20250514");
    }

    #[tokio::test]
    async fn build_request_basic() {
        let session = make_session(default_config());
        let past = session.history.lock().await.clone();
        let input = Input::Text("hello".into());
        let options = TurnOptions::default();

        let req = session.build_request(&input, &options, &past);

        assert_eq!(req.model, "glm-5");
        assert_eq!(req.messages.len(), 1);
        let only = &req.messages[0];
        assert_eq!(only.role, "user");
        assert_eq!(only.content, "hello");
        assert_eq!(req.stream, Some(true));
        assert!(req.response_format.is_none());
    }

    #[tokio::test]
    async fn build_request_with_system_prompt_and_schema() {
        let backend = OpenCodeConfig::builder().stream(false).build();
        let config = SessionConfig::builder()
            .system_prompt("Be helpful")
            .backend(backend)
            .build();
        let session = make_session(config);
        let options = TurnOptions {
            output_schema: Some(serde_json::json!({"type": "object"})),
            ..Default::default()
        };
        let past = session.history.lock().await.clone();

        let req = session.build_request(&Input::Text("test".into()), &options, &past);

        let system = &req.messages[0];
        assert_eq!(system.role, "system");
        assert_eq!(system.content, "Be helpful");
        assert_eq!(req.messages[1].role, "user");
        assert!(req.stream.is_none());
        assert!(req.response_format.is_some());
        assert_eq!(
            req.response_format.as_ref().unwrap().format_type,
            "json_schema"
        );
    }

    #[tokio::test]
    async fn build_request_includes_history() {
        let session = make_session(default_config());
        {
            let mut recorded = session.history.lock().await;
            recorded.push(WireMessage::user("first"));
            recorded.push(WireMessage::assistant("reply"));
        }
        let past = session.history.lock().await.clone();

        let req = session.build_request(
            &Input::Text("second".into()),
            &TurnOptions::default(),
            &past,
        );

        assert_eq!(req.messages.len(), 3);
        assert_eq!(req.messages[0].content, "first");
        assert_eq!(req.messages[1].content, "reply");
        assert_eq!(req.messages[2].content, "second");
    }

    #[tokio::test]
    async fn turn_stream_rejects_when_turn_active() {
        let mut session = make_session(default_config());
        // A populated interrupt slot is what marks a turn as being in progress.
        let (tx, _rx) = tokio::sync::oneshot::channel();
        *session.active_interrupt.lock().await = Some(tx);

        let outcome = session
            .turn_stream(Input::Text("hi".into()), TurnOptions::default())
            .await;

        let Err(AgentError::Other { message, .. }) = outcome else {
            panic!("expected Other error");
        };
        assert_eq!(message, "turn already in progress");
    }
}
605}