Skip to main content

pi/providers/
cohere.rs

1//! Cohere Chat API provider implementation.
2//!
3//! This module implements the Provider trait for Cohere's `v2/chat` endpoint,
4//! supporting streaming output text/thinking and function tool calls.
5
6use crate::error::{Error, Result};
7use crate::http::client::Client;
8use crate::model::{
9    AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ThinkingContent,
10    ToolCall, Usage, UserContent,
11};
12use crate::models::CompatConfig;
13use crate::provider::{Context, Provider, StreamOptions, ToolDef};
14use crate::sse::SseStream;
15use async_trait::async_trait;
16use futures::StreamExt;
17use futures::stream::{self, Stream};
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, VecDeque};
20use std::pin::Pin;
21
22// ============================================================================
23// Constants
24// ============================================================================
25
/// Default endpoint for Cohere's `v2/chat` API; `CohereProvider::new` uses it as the initial base URL.
const COHERE_CHAT_API_URL: &str = "https://api.cohere.com/v2/chat";
/// `max_tokens` value placed in the request when the caller's options do not specify one.
const DEFAULT_MAX_TOKENS: u32 = 4096;
28
29// ============================================================================
30// Cohere Provider
31// ============================================================================
32
/// Cohere `v2/chat` streaming provider.
///
/// Built with `CohereProvider::new` and customised through the `with_*` builder methods.
pub struct CohereProvider {
    /// HTTP client used to issue the chat request.
    client: Client,
    /// Model identifier sent in the request body and returned by `model_id()`.
    model: String,
    /// URL the chat request is POSTed to.
    base_url: String,
    /// Provider name returned by `name()`.
    provider: String,
    /// Optional compatibility overrides; its `custom_headers` are added to outgoing requests.
    compat: Option<CompatConfig>,
}
41
42impl CohereProvider {
43    pub fn new(model: impl Into<String>) -> Self {
44        Self {
45            client: Client::new(),
46            model: model.into(),
47            base_url: COHERE_CHAT_API_URL.to_string(),
48            provider: "cohere".to_string(),
49            compat: None,
50        }
51    }
52
53    #[must_use]
54    pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
55        self.provider = provider.into();
56        self
57    }
58
59    #[must_use]
60    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
61        self.base_url = base_url.into();
62        self
63    }
64
65    #[must_use]
66    pub fn with_client(mut self, client: Client) -> Self {
67        self.client = client;
68        self
69    }
70
71    /// Attach provider-specific compatibility overrides.
72    #[must_use]
73    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
74        self.compat = compat;
75        self
76    }
77
78    pub fn build_request(&self, context: &Context<'_>, options: &StreamOptions) -> CohereRequest {
79        let messages = build_cohere_messages(context);
80
81        let tools: Option<Vec<CohereTool>> = if context.tools.is_empty() {
82            None
83        } else {
84            Some(context.tools.iter().map(convert_tool_to_cohere).collect())
85        };
86
87        CohereRequest {
88            model: self.model.clone(),
89            messages,
90            max_tokens: options.max_tokens.or(Some(DEFAULT_MAX_TOKENS)),
91            temperature: options.temperature,
92            tools,
93            stream: true,
94        }
95    }
96}
97
98#[async_trait]
99#[allow(clippy::too_many_lines)]
100impl Provider for CohereProvider {
101    fn name(&self) -> &str {
102        &self.provider
103    }
104
105    fn api(&self) -> &'static str {
106        "cohere-chat"
107    }
108
109    fn model_id(&self) -> &str {
110        &self.model
111    }
112
113    async fn stream(
114        &self,
115        context: &Context<'_>,
116        options: &StreamOptions,
117    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
118        let has_authorization_header = options
119            .headers
120            .keys()
121            .any(|key| key.eq_ignore_ascii_case("authorization"));
122
123        let auth_value = if has_authorization_header {
124            None
125        } else {
126            Some(
127                options
128                    .api_key
129                    .clone()
130                    .or_else(|| std::env::var("COHERE_API_KEY").ok())
131                    .ok_or_else(|| Error::provider("cohere", "Missing API key for Cohere. Set COHERE_API_KEY or configure in settings."))?,
132            )
133        };
134
135        let request_body = self.build_request(context, options);
136
137        // Content-Type set by .json() below
138        let mut request = self
139            .client
140            .post(&self.base_url)
141            .header("Accept", "text/event-stream");
142
143        if let Some(auth_value) = auth_value {
144            request = request.header("Authorization", format!("Bearer {auth_value}"));
145        }
146
147        // Apply provider-specific custom headers from compat config.
148        if let Some(compat) = &self.compat {
149            if let Some(custom_headers) = &compat.custom_headers {
150                for (key, value) in custom_headers {
151                    request = request.header(key, value);
152                }
153            }
154        }
155
156        // Per-request headers from StreamOptions (highest priority).
157        for (key, value) in &options.headers {
158            request = request.header(key, value);
159        }
160
161        let request = request.json(&request_body)?;
162
163        let response = Box::pin(request.send()).await?;
164        let status = response.status();
165        if !(200..300).contains(&status) {
166            let body = response
167                .text()
168                .await
169                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
170            return Err(Error::provider(
171                "cohere",
172                format!("Cohere API error (HTTP {status}): {body}"),
173            ));
174        }
175
176        let content_type = response
177            .headers()
178            .iter()
179            .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
180            .map(|(_, value)| value.to_ascii_lowercase());
181        if !content_type
182            .as_deref()
183            .is_some_and(|value| value.contains("text/event-stream"))
184        {
185            let message = content_type.map_or_else(
186                || {
187                    format!(
188                        "Cohere API protocol error (HTTP {status}): missing Content-Type (expected text/event-stream)"
189                    )
190                },
191                |value| {
192                    format!(
193                        "Cohere API protocol error (HTTP {status}): unexpected Content-Type {value} (expected text/event-stream)"
194                    )
195                },
196            );
197            return Err(Error::api(message));
198        }
199
200        let event_source = SseStream::new(response.bytes_stream());
201
202        let model = self.model.clone();
203        let api = self.api().to_string();
204        let provider = self.name().to_string();
205
206        let stream = stream::unfold(
207            StreamState::new(event_source, model, api, provider),
208            |mut state| async move {
209                loop {
210                    if let Some(event) = state.pending_events.pop_front() {
211                        return Some((Ok(event), state));
212                    }
213
214                    if state.finished {
215                        return None;
216                    }
217
218                    match state.event_source.next().await {
219                        Some(Ok(msg)) => {
220                            if msg.data == "[DONE]" {
221                                state.finish();
222                                continue;
223                            }
224
225                            if let Err(e) = state.process_event(&msg.data) {
226                                return Some((Err(e), state));
227                            }
228                        }
229                        Some(Err(e)) => {
230                            let err = Error::api(format!("SSE error: {e}"));
231                            return Some((Err(err), state));
232                        }
233                        None => {
234                            // Stream ended without message-end; surface a consistent error.
235                            return Some((
236                                Err(Error::api("Stream ended without Done event")),
237                                state,
238                            ));
239                        }
240                    }
241                }
242            },
243        );
244
245        Ok(Box::pin(stream))
246    }
247}
248
249// ============================================================================
250// Stream State
251// ============================================================================
252
/// Accumulates the streamed pieces of the tool call currently being received.
struct ToolCallAccum {
    /// Position of the tool-call block inside the partial message's `content`.
    content_index: usize,
    /// Tool-call id taken from the `tool-call-start` chunk.
    id: String,
    /// Function name taken from the `tool-call-start` chunk.
    name: String,
    /// Raw argument text concatenated across chunks; parsed as JSON at `tool-call-end`.
    arguments: String,
}
259
/// Mutable state threaded through the response stream while chunks are decoded.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Source of server-sent events for the response body.
    event_source: SseStream<S>,
    /// Assistant message assembled so far.
    partial: AssistantMessage,
    /// Events produced but not yet handed to the consumer; drained before more input is read.
    pending_events: VecDeque<StreamEvent>,
    /// Whether the `Start` event has already been queued.
    started: bool,
    /// Once set, the stream yields nothing further after `pending_events` is drained.
    finished: bool,
    /// Maps a chunk's content `index` to the block position in `partial.content`.
    content_index_map: HashMap<u32, usize>,
    /// Tool call whose argument fragments are still arriving, if any.
    active_tool_call: Option<ToolCallAccum>,
}
272
273impl<S> StreamState<S>
274where
275    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
276{
277    fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
278        Self {
279            event_source,
280            partial: AssistantMessage {
281                content: Vec::new(),
282                api,
283                provider,
284                model,
285                usage: Usage::default(),
286                stop_reason: StopReason::Stop,
287                error_message: None,
288                timestamp: chrono::Utc::now().timestamp_millis(),
289            },
290            pending_events: VecDeque::new(),
291            started: false,
292            finished: false,
293            content_index_map: HashMap::new(),
294            active_tool_call: None,
295        }
296    }
297
298    fn ensure_started(&mut self) {
299        if !self.started {
300            self.started = true;
301            self.pending_events.push_back(StreamEvent::Start {
302                partial: self.partial.clone(),
303            });
304        }
305    }
306
307    fn content_block_for(&mut self, idx: u32, kind: CohereContentKind) -> usize {
308        if let Some(existing) = self.content_index_map.get(&idx) {
309            return *existing;
310        }
311
312        let content_index = self.partial.content.len();
313        match kind {
314            CohereContentKind::Text => {
315                self.partial
316                    .content
317                    .push(ContentBlock::Text(TextContent::new("")));
318                self.pending_events
319                    .push_back(StreamEvent::TextStart { content_index });
320            }
321            CohereContentKind::Thinking => {
322                self.partial
323                    .content
324                    .push(ContentBlock::Thinking(ThinkingContent {
325                        thinking: String::new(),
326                        thinking_signature: None,
327                    }));
328                self.pending_events
329                    .push_back(StreamEvent::ThinkingStart { content_index });
330            }
331        }
332
333        self.content_index_map.insert(idx, content_index);
334        content_index
335    }
336
337    #[allow(clippy::too_many_lines)]
338    fn process_event(&mut self, data: &str) -> Result<()> {
339        let chunk: CohereStreamChunk =
340            serde_json::from_str(data).map_err(|e| Error::api(format!("JSON parse error: {e}")))?;
341
342        match chunk {
343            CohereStreamChunk::MessageStart { .. } => {
344                self.ensure_started();
345            }
346            CohereStreamChunk::ContentStart { index, delta } => {
347                self.ensure_started();
348                let (kind, initial) = delta.message.content.kind_and_text();
349                let content_index = self.content_block_for(index, kind);
350
351                if !initial.is_empty() {
352                    match kind {
353                        CohereContentKind::Text => {
354                            if let Some(ContentBlock::Text(t)) =
355                                self.partial.content.get_mut(content_index)
356                            {
357                                t.text.push_str(&initial);
358                            }
359                            self.pending_events.push_back(StreamEvent::TextDelta {
360                                content_index,
361                                delta: initial,
362                            });
363                        }
364                        CohereContentKind::Thinking => {
365                            if let Some(ContentBlock::Thinking(t)) =
366                                self.partial.content.get_mut(content_index)
367                            {
368                                t.thinking.push_str(&initial);
369                            }
370                            self.pending_events.push_back(StreamEvent::ThinkingDelta {
371                                content_index,
372                                delta: initial,
373                            });
374                        }
375                    }
376                }
377            }
378            CohereStreamChunk::ContentDelta { index, delta } => {
379                self.ensure_started();
380                let (kind, delta_text) = delta.message.content.kind_and_text();
381                let content_index = self.content_block_for(index, kind);
382
383                match kind {
384                    CohereContentKind::Text => {
385                        if let Some(ContentBlock::Text(t)) =
386                            self.partial.content.get_mut(content_index)
387                        {
388                            t.text.push_str(&delta_text);
389                        }
390                        self.pending_events.push_back(StreamEvent::TextDelta {
391                            content_index,
392                            delta: delta_text,
393                        });
394                    }
395                    CohereContentKind::Thinking => {
396                        if let Some(ContentBlock::Thinking(t)) =
397                            self.partial.content.get_mut(content_index)
398                        {
399                            t.thinking.push_str(&delta_text);
400                        }
401                        self.pending_events.push_back(StreamEvent::ThinkingDelta {
402                            content_index,
403                            delta: delta_text,
404                        });
405                    }
406                }
407            }
408            CohereStreamChunk::ContentEnd { index } => {
409                if let Some(content_index) = self.content_index_map.get(&index).copied() {
410                    match self.partial.content.get(content_index) {
411                        Some(ContentBlock::Text(t)) => {
412                            self.pending_events.push_back(StreamEvent::TextEnd {
413                                content_index,
414                                content: t.text.clone(),
415                            });
416                        }
417                        Some(ContentBlock::Thinking(t)) => {
418                            self.pending_events.push_back(StreamEvent::ThinkingEnd {
419                                content_index,
420                                content: t.thinking.clone(),
421                            });
422                        }
423                        _ => {}
424                    }
425                }
426            }
427            CohereStreamChunk::ToolCallStart { delta } => {
428                self.ensure_started();
429                let tc = delta.message.tool_calls;
430                let content_index = self.partial.content.len();
431                self.partial.content.push(ContentBlock::ToolCall(ToolCall {
432                    id: tc.id.clone(),
433                    name: tc.function.name.clone(),
434                    arguments: serde_json::Value::Null,
435                    thought_signature: None,
436                }));
437
438                self.active_tool_call = Some(ToolCallAccum {
439                    content_index,
440                    id: tc.id,
441                    name: tc.function.name,
442                    arguments: tc.function.arguments.clone(),
443                });
444
445                self.pending_events
446                    .push_back(StreamEvent::ToolCallStart { content_index });
447                if !tc.function.arguments.is_empty() {
448                    self.pending_events.push_back(StreamEvent::ToolCallDelta {
449                        content_index,
450                        delta: tc.function.arguments,
451                    });
452                }
453            }
454            CohereStreamChunk::ToolCallDelta { delta } => {
455                self.ensure_started();
456                if let Some(active) = self.active_tool_call.as_mut() {
457                    active
458                        .arguments
459                        .push_str(&delta.message.tool_calls.function.arguments);
460                    self.pending_events.push_back(StreamEvent::ToolCallDelta {
461                        content_index: active.content_index,
462                        delta: delta.message.tool_calls.function.arguments,
463                    });
464                }
465            }
466            CohereStreamChunk::ToolCallEnd => {
467                if let Some(active) = self.active_tool_call.take() {
468                    self.ensure_started();
469                    let parsed_args: serde_json::Value = serde_json::from_str(&active.arguments)
470                        .unwrap_or_else(|e| {
471                            tracing::warn!(
472                                error = %e,
473                                raw = %active.arguments,
474                                "Failed to parse tool arguments as JSON"
475                            );
476                            serde_json::Value::Null
477                        });
478
479                    self.partial.stop_reason = StopReason::ToolUse;
480                    self.pending_events.push_back(StreamEvent::ToolCallEnd {
481                        content_index: active.content_index,
482                        tool_call: ToolCall {
483                            id: active.id,
484                            name: active.name,
485                            arguments: parsed_args.clone(),
486                            thought_signature: None,
487                        },
488                    });
489
490                    if let Some(ContentBlock::ToolCall(block)) =
491                        self.partial.content.get_mut(active.content_index)
492                    {
493                        block.arguments = parsed_args;
494                    }
495                }
496            }
497            CohereStreamChunk::MessageEnd { delta } => {
498                self.ensure_started();
499                self.partial.usage.input = delta.usage.tokens.input_tokens;
500                self.partial.usage.output = delta.usage.tokens.output_tokens;
501                self.partial.usage.total_tokens =
502                    delta.usage.tokens.input_tokens + delta.usage.tokens.output_tokens;
503
504                self.partial.stop_reason = match delta.finish_reason.as_str() {
505                    "MAX_TOKENS" => StopReason::Length,
506                    "TOOL_CALL" => StopReason::ToolUse,
507                    "ERROR" => StopReason::Error,
508                    _ => StopReason::Stop,
509                };
510
511                self.finish();
512            }
513            CohereStreamChunk::Unknown => {}
514        }
515
516        Ok(())
517    }
518
519    fn finish(&mut self) {
520        if self.finished {
521            return;
522        }
523        let reason = self.partial.stop_reason;
524        self.pending_events.push_back(StreamEvent::Done {
525            reason,
526            message: std::mem::take(&mut self.partial),
527        });
528        self.finished = true;
529    }
530}
531
532// ============================================================================
533// Cohere API Types (minimal)
534// ============================================================================
535
/// JSON body sent to the chat endpoint.
#[derive(Debug, Serialize)]
pub struct CohereRequest {
    /// Model identifier.
    model: String,
    /// Conversation history in wire format.
    messages: Vec<CohereMessage>,
    /// Output token limit; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Sampling temperature; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Tool definitions; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<CohereTool>>,
    /// Whether a streamed response is requested.
    stream: bool,
}
548
549#[derive(Debug, Serialize)]
550#[serde(tag = "role", rename_all = "lowercase")]
551enum CohereMessage {
552    System {
553        content: String,
554    },
555    User {
556        content: String,
557    },
558    Assistant {
559        #[serde(skip_serializing_if = "Option::is_none")]
560        content: Option<String>,
561        #[serde(skip_serializing_if = "Option::is_none")]
562        tool_calls: Option<Vec<CohereToolCallRef>>,
563        #[serde(skip_serializing_if = "Option::is_none")]
564        tool_plan: Option<String>,
565    },
566    Tool {
567        content: String,
568        tool_call_id: String,
569    },
570}
571
/// A tool call listed inside an assistant message of the request.
#[derive(Debug, Serialize)]
struct CohereToolCallRef {
    /// Identifier of the tool call.
    id: String,
    /// Serialised as `type`; this file always sets it to `"function"`.
    #[serde(rename = "type")]
    r#type: &'static str,
    /// The function that was called and its arguments.
    function: CohereFunctionRef,
}
579
/// Function name and arguments of a tool call listed in the request.
#[derive(Debug, Serialize)]
struct CohereFunctionRef {
    /// Name of the called function.
    name: String,
    /// Arguments as JSON text (a string, not a nested JSON object).
    arguments: String,
}
585
/// A tool definition offered to the model in the request.
#[derive(Debug, Serialize)]
struct CohereTool {
    /// Serialised as `type`; this file always sets it to `"function"`.
    #[serde(rename = "type")]
    r#type: &'static str,
    /// Description of the callable function.
    function: CohereFunction,
}
592
/// Function description inside a tool definition.
#[derive(Debug, Serialize)]
struct CohereFunction {
    /// Function name.
    name: String,
    /// Human-readable description; left out of the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// Parameter schema, passed through as an arbitrary JSON value.
    parameters: serde_json::Value,
}
600
601fn convert_tool_to_cohere(tool: &ToolDef) -> CohereTool {
602    CohereTool {
603        r#type: "function",
604        function: CohereFunction {
605            name: tool.name.clone(),
606            description: if tool.description.trim().is_empty() {
607                None
608            } else {
609                Some(tool.description.clone())
610            },
611            parameters: tool.parameters.clone(),
612        },
613    }
614}
615
616fn build_cohere_messages(context: &Context<'_>) -> Vec<CohereMessage> {
617    let mut out = Vec::new();
618
619    if let Some(system) = &context.system_prompt {
620        out.push(CohereMessage::System {
621            content: system.to_string(),
622        });
623    }
624
625    for message in context.messages.iter() {
626        match message {
627            Message::User(user) => out.push(CohereMessage::User {
628                content: extract_text_user_content(&user.content),
629            }),
630            Message::Custom(custom) => out.push(CohereMessage::User {
631                content: custom.content.clone(),
632            }),
633            Message::Assistant(assistant) => {
634                let mut text = String::new();
635                let mut tool_calls = Vec::new();
636
637                for block in &assistant.content {
638                    match block {
639                        ContentBlock::Text(t) => text.push_str(&t.text),
640                        ContentBlock::ToolCall(tc) => tool_calls.push(CohereToolCallRef {
641                            id: tc.id.clone(),
642                            r#type: "function",
643                            function: CohereFunctionRef {
644                                name: tc.name.clone(),
645                                arguments: tc.arguments.to_string(),
646                            },
647                        }),
648                        _ => {}
649                    }
650                }
651
652                out.push(CohereMessage::Assistant {
653                    content: if tool_calls.is_empty() && !text.is_empty() {
654                        Some(text)
655                    } else {
656                        None
657                    },
658                    tool_calls: if tool_calls.is_empty() {
659                        None
660                    } else {
661                        Some(tool_calls)
662                    },
663                    tool_plan: None,
664                });
665            }
666            Message::ToolResult(result) => {
667                let mut content = String::new();
668                for (i, block) in result.content.iter().enumerate() {
669                    if i > 0 {
670                        content.push('\n');
671                    }
672                    if let ContentBlock::Text(t) = block {
673                        content.push_str(&t.text);
674                    }
675                }
676                out.push(CohereMessage::Tool {
677                    content,
678                    tool_call_id: result.tool_call_id.clone(),
679                });
680            }
681        }
682    }
683
684    out
685}
686
687fn extract_text_user_content(content: &UserContent) -> String {
688    match content {
689        UserContent::Text(text) => text.clone(),
690        UserContent::Blocks(blocks) => {
691            let mut out = String::new();
692            for block in blocks {
693                match block {
694                    ContentBlock::Text(t) => out.push_str(&t.text),
695                    ContentBlock::Image(img) => {
696                        use std::fmt::Write as _;
697                        let _ =
698                            write!(out, "[Image: {} ({} bytes)]", img.mime_type, img.data.len());
699                    }
700                    _ => {}
701                }
702            }
703            out
704        }
705    }
706}
707
708// ============================================================================
709// Cohere streaming chunk types (minimal, forward-compatible)
710// ============================================================================
711
712#[derive(Debug, Deserialize)]
713#[serde(tag = "type")]
714enum CohereStreamChunk {
715    #[serde(rename = "message-start")]
716    MessageStart { id: Option<String> },
717    #[serde(rename = "content-start")]
718    ContentStart {
719        index: u32,
720        delta: CohereContentStartDelta,
721    },
722    #[serde(rename = "content-delta")]
723    ContentDelta {
724        index: u32,
725        delta: CohereContentDelta,
726    },
727    #[serde(rename = "content-end")]
728    ContentEnd { index: u32 },
729    #[serde(rename = "tool-call-start")]
730    ToolCallStart { delta: CohereToolCallStartDelta },
731    #[serde(rename = "tool-call-delta")]
732    ToolCallDelta { delta: CohereToolCallDelta },
733    #[serde(rename = "tool-call-end")]
734    ToolCallEnd,
735    #[serde(rename = "message-end")]
736    MessageEnd { delta: CohereMessageEndDelta },
737    #[serde(other)]
738    Unknown,
739}
740
/// `delta` payload of a `content-start` chunk.
#[derive(Debug, Deserialize)]
struct CohereContentStartDelta {
    /// Wrapper holding the opening content payload.
    message: CohereDeltaMessage<CohereContentStart>,
}
745
/// `delta` payload of a `content-delta` chunk.
#[derive(Debug, Deserialize)]
struct CohereContentDelta {
    /// Wrapper holding the incremental content payload.
    message: CohereDeltaMessage<CohereContentDeltaPart>,
}
750
/// `message` wrapper shared by the content chunks; `T` is the content payload type.
#[derive(Debug, Deserialize)]
struct CohereDeltaMessage<T> {
    /// The content payload itself.
    content: T,
}
755
756#[derive(Debug, Deserialize)]
757#[serde(tag = "type")]
758enum CohereContentStart {
759    #[serde(rename = "text")]
760    Text { text: String },
761    #[serde(rename = "thinking")]
762    Thinking { thinking: String },
763}
764
/// Content payload of a `content-delta` chunk.
///
/// Untagged: the variant is chosen by which field is present, with `text`
/// tried before `thinking`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum CohereContentDeltaPart {
    /// Additional text for a text block.
    Text { text: String },
    /// Additional text for a thinking block.
    Thinking { thinking: String },
}
771
/// Kind of content block a chunk refers to.
#[derive(Debug, Clone, Copy)]
enum CohereContentKind {
    /// Regular output text.
    Text,
    /// Thinking text.
    Thinking,
}
777
778impl CohereContentStart {
779    fn kind_and_text(self) -> (CohereContentKind, String) {
780        match self {
781            Self::Text { text } => (CohereContentKind::Text, text),
782            Self::Thinking { thinking } => (CohereContentKind::Thinking, thinking),
783        }
784    }
785}
786
787impl CohereContentDeltaPart {
788    fn kind_and_text(self) -> (CohereContentKind, String) {
789        match self {
790            Self::Text { text } => (CohereContentKind::Text, text),
791            Self::Thinking { thinking } => (CohereContentKind::Thinking, thinking),
792        }
793    }
794}
795
/// `delta` payload of a `tool-call-start` chunk.
#[derive(Debug, Deserialize)]
struct CohereToolCallStartDelta {
    /// Wrapper holding the opening tool-call payload.
    message: CohereToolCallMessage<CohereToolCallStartBody>,
}
800
/// `delta` payload of a `tool-call-delta` chunk.
#[derive(Debug, Deserialize)]
struct CohereToolCallDelta {
    /// Wrapper holding the incremental tool-call payload.
    message: CohereToolCallMessage<CohereToolCallDeltaBody>,
}
805
/// `message` wrapper shared by the tool-call chunks; `T` is the tool-call payload type.
#[derive(Debug, Deserialize)]
struct CohereToolCallMessage<T> {
    /// The tool-call payload (a single object, despite the plural field name).
    tool_calls: T,
}
810
/// Tool-call payload of a `tool-call-start` chunk.
#[derive(Debug, Deserialize)]
struct CohereToolCallStartBody {
    /// Identifier of the tool call.
    id: String,
    /// Function name and the first piece of argument text.
    function: CohereToolCallFunctionStart,
}
816
/// Function details carried by a `tool-call-start` chunk.
#[derive(Debug, Deserialize)]
struct CohereToolCallFunctionStart {
    /// Name of the function being called.
    name: String,
    /// Initial argument text; may be empty. Both fields are required for the chunk to decode.
    arguments: String,
}
822
/// Tool-call payload of a `tool-call-delta` chunk.
#[derive(Debug, Deserialize)]
struct CohereToolCallDeltaBody {
    /// Holder of the argument fragment.
    function: CohereToolCallFunctionDelta,
}
827
/// Function details carried by a `tool-call-delta` chunk.
#[derive(Debug, Deserialize)]
struct CohereToolCallFunctionDelta {
    /// Next fragment of the argument text.
    arguments: String,
}
832
/// `delta` payload of a `message-end` chunk. Both fields are required for the chunk to decode.
#[derive(Debug, Deserialize)]
struct CohereMessageEndDelta {
    /// Why generation stopped; `MAX_TOKENS`, `TOOL_CALL` and `ERROR` are mapped
    /// specially by the stream state, anything else is treated as a normal stop.
    finish_reason: String,
    /// Usage statistics for the request.
    usage: CohereUsage,
}
838
/// Usage section of a `message-end` chunk; only the `tokens` object is decoded.
#[derive(Debug, Deserialize)]
struct CohereUsage {
    /// Input and output token counts.
    tokens: CohereUsageTokens,
}
843
/// Token counts reported in a `message-end` chunk.
#[derive(Debug, Deserialize)]
struct CohereUsageTokens {
    /// Tokens counted for the input.
    input_tokens: u64,
    /// Tokens counted for the output.
    output_tokens: u64,
}
849
850// ============================================================================
851// Tests
852// ============================================================================
853
854#[cfg(test)]
855mod tests {
856    use super::*;
857    use asupersync::runtime::RuntimeBuilder;
858    use futures::stream;
859    use serde_json::{Value, json};
860    use std::collections::HashMap;
861    use std::io::{Read, Write};
862    use std::net::TcpListener;
863    use std::path::PathBuf;
864    use std::sync::mpsc;
865    use std::time::Duration;
866
867    // ─── Fixture infrastructure ─────────────────────────────────────────
868
    /// Top-level shape of a provider-response fixture file: a list of cases.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        /// Cases contained in the fixture file.
        cases: Vec<ProviderCase>,
    }
873
    /// One fixture case: a named list of raw events and the expected summaries.
    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        /// Case name.
        name: String,
        /// Raw JSON events; presumably the chunks fed to the stream state (usage is outside this view).
        events: Vec<Value>,
        /// Event summaries the case is expected to produce.
        expected: Vec<EventSummary>,
    }
880
    /// Comparable summary of a `StreamEvent`; every field except `kind` defaults
    /// to `None` when absent from the fixture JSON.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        /// Event kind label, e.g. `"start"`, `"text_delta"`, `"done"`.
        kind: String,
        /// Content block index, for events that carry one.
        #[serde(default)]
        content_index: Option<usize>,
        /// Delta text, for delta events.
        #[serde(default)]
        delta: Option<String>,
        /// Final content, for end events.
        #[serde(default)]
        content: Option<String>,
        /// Stop or error reason, for terminal events.
        #[serde(default)]
        reason: Option<String>,
    }
893
894    fn load_fixture(file_name: &str) -> ProviderFixture {
895        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
896            .join("tests/fixtures/provider_responses")
897            .join(file_name);
898        let raw = std::fs::read_to_string(path).expect("fixture read");
899        serde_json::from_str(&raw).expect("fixture parse")
900    }
901
902    fn summarize_event(event: &StreamEvent) -> EventSummary {
903        match event {
904            StreamEvent::Start { .. } => EventSummary {
905                kind: "start".to_string(),
906                content_index: None,
907                delta: None,
908                content: None,
909                reason: None,
910            },
911            StreamEvent::TextStart { content_index, .. } => EventSummary {
912                kind: "text_start".to_string(),
913                content_index: Some(*content_index),
914                delta: None,
915                content: None,
916                reason: None,
917            },
918            StreamEvent::TextDelta {
919                content_index,
920                delta,
921                ..
922            } => EventSummary {
923                kind: "text_delta".to_string(),
924                content_index: Some(*content_index),
925                delta: Some(delta.clone()),
926                content: None,
927                reason: None,
928            },
929            StreamEvent::TextEnd {
930                content_index,
931                content,
932                ..
933            } => EventSummary {
934                kind: "text_end".to_string(),
935                content_index: Some(*content_index),
936                delta: None,
937                content: Some(content.clone()),
938                reason: None,
939            },
940            StreamEvent::Done { reason, .. } => EventSummary {
941                kind: "done".to_string(),
942                content_index: None,
943                delta: None,
944                content: None,
945                reason: Some(reason_to_string(*reason)),
946            },
947            StreamEvent::Error { reason, .. } => EventSummary {
948                kind: "error".to_string(),
949                content_index: None,
950                delta: None,
951                content: None,
952                reason: Some(reason_to_string(*reason)),
953            },
954            _ => EventSummary {
955                kind: "other".to_string(),
956                content_index: None,
957                delta: None,
958                content: None,
959                reason: None,
960            },
961        }
962    }
963
964    fn reason_to_string(reason: StopReason) -> String {
965        match reason {
966            StopReason::Stop => "stop",
967            StopReason::Length => "length",
968            StopReason::ToolUse => "tool_use",
969            StopReason::Error => "error",
970            StopReason::Aborted => "aborted",
971        }
972        .to_string()
973    }
974
975    #[test]
976    fn test_stream_fixtures() {
977        let fixture = load_fixture("cohere_stream.json");
978        for case in fixture.cases {
979            let events = collect_events(&case.events);
980            let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
981            assert_eq!(summaries, case.expected, "case {}", case.name);
982        }
983    }
984
985    // ─── Existing tests ─────────────────────────────────────────────────
986
987    #[test]
988    fn test_provider_info() {
989        let provider = CohereProvider::new("command-r");
990        assert_eq!(provider.name(), "cohere");
991        assert_eq!(provider.api(), "cohere-chat");
992    }
993
994    #[test]
995    fn test_build_request_includes_system_tools_and_v2_shape() {
996        let provider = CohereProvider::new("command-r");
997        let context = Context::owned(
998            Some("You are concise.".to_string()),
999            vec![Message::User(crate::model::UserMessage {
1000                content: UserContent::Text("Ping".to_string()),
1001                timestamp: 0,
1002            })],
1003            vec![ToolDef {
1004                name: "search".to_string(),
1005                description: "Search docs".to_string(),
1006                parameters: json!({
1007                    "type": "object",
1008                    "properties": {
1009                        "q": { "type": "string" }
1010                    },
1011                    "required": ["q"]
1012                }),
1013            }],
1014        );
1015        let options = StreamOptions {
1016            temperature: Some(0.2),
1017            max_tokens: Some(123),
1018            ..Default::default()
1019        };
1020
1021        let request = provider.build_request(&context, &options);
1022        let value = serde_json::to_value(&request).expect("serialize request");
1023
1024        assert_eq!(value["model"], "command-r");
1025        assert_eq!(value["messages"][0]["role"], "system");
1026        assert_eq!(value["messages"][0]["content"], "You are concise.");
1027        assert_eq!(value["messages"][1]["role"], "user");
1028        assert_eq!(value["messages"][1]["content"], "Ping");
1029        assert_eq!(value["stream"], true);
1030        assert_eq!(value["max_tokens"], 123);
1031        let temperature = value["temperature"]
1032            .as_f64()
1033            .expect("temperature should be numeric");
1034        assert!((temperature - 0.2).abs() < 1e-6);
1035        assert_eq!(value["tools"][0]["type"], "function");
1036        assert_eq!(value["tools"][0]["function"]["name"], "search");
1037        assert_eq!(value["tools"][0]["function"]["description"], "Search docs");
1038        assert_eq!(
1039            value["tools"][0]["function"]["parameters"],
1040            json!({
1041                "type": "object",
1042                "properties": {
1043                    "q": { "type": "string" }
1044                },
1045                "required": ["q"]
1046            })
1047        );
1048    }
1049
1050    #[test]
1051    fn test_convert_tool_to_cohere_omits_empty_description() {
1052        let tool = ToolDef {
1053            name: "echo".to_string(),
1054            description: "   ".to_string(),
1055            parameters: json!({
1056                "type": "object",
1057                "properties": {
1058                    "text": { "type": "string" }
1059                }
1060            }),
1061        };
1062
1063        let converted = convert_tool_to_cohere(&tool);
1064        let value = serde_json::to_value(converted).expect("serialize converted tool");
1065        assert_eq!(value["type"], "function");
1066        assert_eq!(value["function"]["name"], "echo");
1067        assert!(value["function"].get("description").is_none());
1068    }
1069
1070    #[test]
1071    fn test_stream_parses_text_and_tool_call() {
1072        let runtime = RuntimeBuilder::current_thread()
1073            .build()
1074            .expect("runtime build");
1075
1076        runtime.block_on(async move {
1077            let events = [
1078                serde_json::json!({ "type": "message-start", "id": "msg_1" }),
1079                serde_json::json!({
1080                    "type": "content-start",
1081                    "index": 0,
1082                    "delta": { "message": { "content": { "type": "text", "text": "Hello" } } }
1083                }),
1084                serde_json::json!({
1085                    "type": "content-delta",
1086                    "index": 0,
1087                    "delta": { "message": { "content": { "text": " world" } } }
1088                }),
1089                serde_json::json!({ "type": "content-end", "index": 0 }),
1090                serde_json::json!({
1091                    "type": "tool-call-start",
1092                    "delta": { "message": { "tool_calls": { "id": "call_1", "type": "function", "function": { "name": "echo", "arguments": "{\"text\":\"hi\"}" } } } }
1093                }),
1094                serde_json::json!({ "type": "tool-call-end" }),
1095                serde_json::json!({
1096                    "type": "message-end",
1097                    "delta": { "finish_reason": "TOOL_CALL", "usage": { "tokens": { "input_tokens": 1, "output_tokens": 2 } } }
1098                }),
1099            ];
1100
1101            let byte_stream = stream::iter(
1102                events
1103                    .iter()
1104                    .map(|event| format!("data: {}\n\n", serde_json::to_string(event).unwrap()))
1105                    .map(|s| Ok(s.into_bytes())),
1106            );
1107
1108            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1109            let mut state = StreamState::new(
1110                event_source,
1111                "command-r".to_string(),
1112                "cohere-chat".to_string(),
1113                "cohere".to_string(),
1114            );
1115
1116            let mut out = Vec::new();
1117            while let Some(item) = state.event_source.next().await {
1118                let msg = item.expect("SSE event");
1119                state.process_event(&msg.data).expect("process_event");
1120                out.extend(state.pending_events.drain(..));
1121                if state.finished {
1122                    break;
1123                }
1124            }
1125
1126            assert!(matches!(out.first(), Some(StreamEvent::Start { .. })));
1127            assert!(out.iter().any(|e| matches!(e, StreamEvent::TextDelta { delta, .. } if delta.contains("Hello"))));
1128            assert!(out.iter().any(|e| matches!(e, StreamEvent::ToolCallEnd { tool_call, .. } if tool_call.name == "echo")));
1129            assert!(out.iter().any(|e| matches!(e, StreamEvent::Done { reason: StopReason::ToolUse, .. })));
1130        });
1131    }
1132
1133    #[test]
1134    fn test_stream_parses_thinking_and_max_tokens_stop_reason() {
1135        let events = vec![
1136            json!({ "type": "message-start", "id": "msg_1" }),
1137            json!({
1138                "type": "content-start",
1139                "index": 0,
1140                "delta": { "message": { "content": { "type": "thinking", "thinking": "Plan" } } }
1141            }),
1142            json!({
1143                "type": "content-delta",
1144                "index": 0,
1145                "delta": { "message": { "content": { "thinking": " more" } } }
1146            }),
1147            json!({ "type": "content-end", "index": 0 }),
1148            json!({
1149                "type": "message-end",
1150                "delta": { "finish_reason": "MAX_TOKENS", "usage": { "tokens": { "input_tokens": 2, "output_tokens": 3 } } }
1151            }),
1152        ];
1153
1154        let out = collect_events(&events);
1155        assert!(
1156            out.iter()
1157                .any(|e| matches!(e, StreamEvent::ThinkingStart { .. }))
1158        );
1159        assert!(out.iter().any(
1160            |e| matches!(e, StreamEvent::ThinkingDelta { delta, .. } if delta.contains("Plan"))
1161        ));
1162        assert!(
1163            out.iter()
1164                .any(|e| matches!(e, StreamEvent::ThinkingEnd { content, .. } if content.contains("Plan more")))
1165        );
1166        assert!(out.iter().any(|e| matches!(
1167            e,
1168            StreamEvent::Done {
1169                reason: StopReason::Length,
1170                ..
1171            }
1172        )));
1173    }
1174
1175    #[test]
1176    fn test_stream_sets_bearer_auth_header() {
1177        let captured = run_stream_and_capture_headers(Some("test-cohere-key"), HashMap::new())
1178            .expect("captured request");
1179        assert_eq!(
1180            captured.headers.get("authorization").map(String::as_str),
1181            Some("Bearer test-cohere-key")
1182        );
1183        assert_eq!(
1184            captured.headers.get("accept").map(String::as_str),
1185            Some("text/event-stream")
1186        );
1187
1188        let body: Value = serde_json::from_str(&captured.body).expect("body json");
1189        assert_eq!(body["model"], "command-r");
1190        assert_eq!(body["stream"], true);
1191    }
1192
1193    #[test]
1194    fn test_stream_uses_existing_authorization_header_without_api_key() {
1195        let mut headers = HashMap::new();
1196        headers.insert(
1197            "Authorization".to_string(),
1198            "Bearer from-custom-header".to_string(),
1199        );
1200        headers.insert("X-Test".to_string(), "1".to_string());
1201
1202        let captured = run_stream_and_capture_headers(None, headers).expect("captured request");
1203        assert_eq!(
1204            captured.headers.get("authorization").map(String::as_str),
1205            Some("Bearer from-custom-header")
1206        );
1207        assert_eq!(
1208            captured.headers.get("x-test").map(String::as_str),
1209            Some("1")
1210        );
1211    }
1212
1213    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
1214        let runtime = RuntimeBuilder::current_thread()
1215            .build()
1216            .expect("runtime build");
1217
1218        runtime.block_on(async {
1219            let byte_stream = stream::iter(
1220                events
1221                    .iter()
1222                    .map(|event| {
1223                        format!(
1224                            "data: {}\n\n",
1225                            serde_json::to_string(event).expect("serialize event")
1226                        )
1227                    })
1228                    .map(|s| Ok(s.into_bytes())),
1229            );
1230
1231            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1232            let mut state = StreamState::new(
1233                event_source,
1234                "command-r".to_string(),
1235                "cohere-chat".to_string(),
1236                "cohere".to_string(),
1237            );
1238
1239            let mut out = Vec::new();
1240            while let Some(item) = state.event_source.next().await {
1241                let msg = item.expect("SSE event");
1242                state.process_event(&msg.data).expect("process_event");
1243                out.extend(state.pending_events.drain(..));
1244                if state.finished {
1245                    break;
1246                }
1247            }
1248            out
1249        })
1250    }
1251
    /// HTTP request as observed by the one-shot test server.
    #[derive(Debug)]
    struct CapturedRequest {
        /// Request headers keyed by lower-cased header name (see `parse_headers`).
        headers: HashMap<String, String>,
        /// Request body, decoded as UTF-8 with invalid sequences replaced.
        body: String,
    }
1257
1258    fn run_stream_and_capture_headers(
1259        api_key: Option<&str>,
1260        extra_headers: HashMap<String, String>,
1261    ) -> Option<CapturedRequest> {
1262        let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1263        let provider = CohereProvider::new("command-r").with_base_url(base_url);
1264        let context = Context::owned(
1265            Some("system".to_string()),
1266            vec![Message::User(crate::model::UserMessage {
1267                content: UserContent::Text("ping".to_string()),
1268                timestamp: 0,
1269            })],
1270            Vec::new(),
1271        );
1272        let options = StreamOptions {
1273            api_key: api_key.map(str::to_string),
1274            headers: extra_headers,
1275            ..Default::default()
1276        };
1277
1278        let runtime = RuntimeBuilder::current_thread()
1279            .build()
1280            .expect("runtime build");
1281        runtime.block_on(async {
1282            let mut stream = provider.stream(&context, &options).await.expect("stream");
1283            while let Some(event) = stream.next().await {
1284                if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1285                    break;
1286                }
1287            }
1288        });
1289
1290        rx.recv_timeout(Duration::from_secs(2)).ok()
1291    }
1292
1293    fn success_sse_body() -> String {
1294        [
1295            r#"data: {"type":"message-start","id":"msg_1"}"#,
1296            "",
1297            r#"data: {"type":"message-end","delta":{"finish_reason":"COMPLETE","usage":{"tokens":{"input_tokens":1,"output_tokens":1}}}}"#,
1298            "",
1299        ]
1300        .join("\n")
1301    }
1302
1303    fn spawn_test_server(
1304        status_code: u16,
1305        content_type: &str,
1306        body: &str,
1307    ) -> (String, mpsc::Receiver<CapturedRequest>) {
1308        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
1309        let addr = listener.local_addr().expect("local addr");
1310        let (tx, rx) = mpsc::channel();
1311        let body = body.to_string();
1312        let content_type = content_type.to_string();
1313
1314        std::thread::spawn(move || {
1315            let (mut socket, _) = listener.accept().expect("accept");
1316            socket
1317                .set_read_timeout(Some(Duration::from_secs(2)))
1318                .expect("set read timeout");
1319
1320            let mut bytes = Vec::new();
1321            let mut chunk = [0_u8; 4096];
1322            loop {
1323                match socket.read(&mut chunk) {
1324                    Ok(0) => break,
1325                    Ok(n) => {
1326                        bytes.extend_from_slice(&chunk[..n]);
1327                        if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
1328                            break;
1329                        }
1330                    }
1331                    Err(err)
1332                        if err.kind() == std::io::ErrorKind::WouldBlock
1333                            || err.kind() == std::io::ErrorKind::TimedOut =>
1334                    {
1335                        break;
1336                    }
1337                    Err(err) => panic!("read request failed: {err}"),
1338                }
1339            }
1340
1341            let header_end = bytes
1342                .windows(4)
1343                .position(|window| window == b"\r\n\r\n")
1344                .expect("request header boundary");
1345            let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
1346            let headers = parse_headers(&header_text);
1347            let mut request_body = bytes[header_end + 4..].to_vec();
1348
1349            let content_length = headers
1350                .get("content-length")
1351                .and_then(|value| value.parse::<usize>().ok())
1352                .unwrap_or(0);
1353            while request_body.len() < content_length {
1354                match socket.read(&mut chunk) {
1355                    Ok(0) => break,
1356                    Ok(n) => request_body.extend_from_slice(&chunk[..n]),
1357                    Err(err)
1358                        if err.kind() == std::io::ErrorKind::WouldBlock
1359                            || err.kind() == std::io::ErrorKind::TimedOut =>
1360                    {
1361                        break;
1362                    }
1363                    Err(err) => panic!("read request body failed: {err}"),
1364                }
1365            }
1366
1367            let captured = CapturedRequest {
1368                headers,
1369                body: String::from_utf8_lossy(&request_body).to_string(),
1370            };
1371            tx.send(captured).expect("send captured request");
1372
1373            let reason = match status_code {
1374                401 => "Unauthorized",
1375                500 => "Internal Server Error",
1376                _ => "OK",
1377            };
1378            let response = format!(
1379                "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
1380                body.len()
1381            );
1382            socket
1383                .write_all(response.as_bytes())
1384                .expect("write response");
1385            socket.flush().expect("flush response");
1386        });
1387
1388        (format!("http://{addr}"), rx)
1389    }
1390
1391    fn parse_headers(header_text: &str) -> HashMap<String, String> {
1392        let mut headers = HashMap::new();
1393        for line in header_text.lines().skip(1) {
1394            if let Some((name, value)) = line.split_once(':') {
1395                headers.insert(name.trim().to_ascii_lowercase(), value.trim().to_string());
1396            }
1397        }
1398        headers
1399    }
1400
1401    // ─── Request body format tests ──────────────────────────────────────
1402
1403    #[test]
1404    fn test_build_request_no_system_prompt() {
1405        let provider = CohereProvider::new("command-r-plus");
1406        let context = Context::owned(
1407            None,
1408            vec![Message::User(crate::model::UserMessage {
1409                content: UserContent::Text("Hi".to_string()),
1410                timestamp: 0,
1411            })],
1412            vec![],
1413        );
1414        let options = StreamOptions::default();
1415
1416        let req = provider.build_request(&context, &options);
1417        let value = serde_json::to_value(&req).expect("serialize");
1418
1419        // First message should be user, no system message.
1420        assert_eq!(value["messages"][0]["role"], "user");
1421        assert_eq!(value["messages"][0]["content"], "Hi");
1422        // No system role message at all.
1423        let msgs = value["messages"].as_array().unwrap();
1424        assert!(
1425            !msgs.iter().any(|m| m["role"] == "system"),
1426            "No system message should be present"
1427        );
1428    }
1429
1430    #[test]
1431    fn test_build_request_default_max_tokens() {
1432        let provider = CohereProvider::new("command-r");
1433        let context = Context::owned(
1434            None,
1435            vec![Message::User(crate::model::UserMessage {
1436                content: UserContent::Text("test".to_string()),
1437                timestamp: 0,
1438            })],
1439            vec![],
1440        );
1441        let options = StreamOptions::default();
1442
1443        let req = provider.build_request(&context, &options);
1444        let value = serde_json::to_value(&req).expect("serialize");
1445
1446        assert_eq!(value["max_tokens"], DEFAULT_MAX_TOKENS);
1447    }
1448
1449    #[test]
1450    fn test_build_request_no_tools_omits_tools_field() {
1451        let provider = CohereProvider::new("command-r");
1452        let context = Context::owned(
1453            None,
1454            vec![Message::User(crate::model::UserMessage {
1455                content: UserContent::Text("test".to_string()),
1456                timestamp: 0,
1457            })],
1458            vec![],
1459        );
1460        let options = StreamOptions::default();
1461
1462        let req = provider.build_request(&context, &options);
1463        let value = serde_json::to_value(&req).expect("serialize");
1464
1465        assert!(
1466            value.get("tools").is_none() || value["tools"].is_null(),
1467            "tools field should be omitted when empty"
1468        );
1469    }
1470
1471    #[test]
1472    fn test_build_request_full_conversation_with_tool_call_and_result() {
1473        let provider = CohereProvider::new("command-r");
1474        let context = Context::owned(
1475            Some("Be concise.".to_string()),
1476            vec![
1477                Message::User(crate::model::UserMessage {
1478                    content: UserContent::Text("Read /tmp/a.txt".to_string()),
1479                    timestamp: 0,
1480                }),
1481                Message::assistant(AssistantMessage {
1482                    content: vec![ContentBlock::ToolCall(ToolCall {
1483                        id: "call_1".to_string(),
1484                        name: "read".to_string(),
1485                        arguments: serde_json::json!({"path": "/tmp/a.txt"}),
1486                        thought_signature: None,
1487                    })],
1488                    api: "cohere-chat".to_string(),
1489                    provider: "cohere".to_string(),
1490                    model: "command-r".to_string(),
1491                    usage: Usage::default(),
1492                    stop_reason: StopReason::ToolUse,
1493                    error_message: None,
1494                    timestamp: 1,
1495                }),
1496                Message::tool_result(crate::model::ToolResultMessage {
1497                    tool_call_id: "call_1".to_string(),
1498                    tool_name: "read".to_string(),
1499                    content: vec![ContentBlock::Text(TextContent::new("file contents"))],
1500                    details: None,
1501                    is_error: false,
1502                    timestamp: 2,
1503                }),
1504            ],
1505            vec![ToolDef {
1506                name: "read".to_string(),
1507                description: "Read a file".to_string(),
1508                parameters: json!({"type": "object"}),
1509            }],
1510        );
1511        let options = StreamOptions::default();
1512
1513        let req = provider.build_request(&context, &options);
1514        let value = serde_json::to_value(&req).expect("serialize");
1515
1516        let msgs = value["messages"].as_array().unwrap();
1517        // system, user, assistant, tool
1518        assert_eq!(msgs.len(), 4);
1519        assert_eq!(msgs[0]["role"], "system");
1520        assert_eq!(msgs[1]["role"], "user");
1521        assert_eq!(msgs[2]["role"], "assistant");
1522        assert_eq!(msgs[3]["role"], "tool");
1523
1524        // Assistant message should have tool_calls, not content text.
1525        assert!(msgs[2].get("content").is_none() || msgs[2]["content"].is_null());
1526        let tool_calls = msgs[2]["tool_calls"].as_array().unwrap();
1527        assert_eq!(tool_calls.len(), 1);
1528        assert_eq!(tool_calls[0]["id"], "call_1");
1529        assert_eq!(tool_calls[0]["type"], "function");
1530        assert_eq!(tool_calls[0]["function"]["name"], "read");
1531
1532        // Tool result should reference the tool_call_id.
1533        assert_eq!(msgs[3]["tool_call_id"], "call_1");
1534        assert_eq!(msgs[3]["content"], "file contents");
1535    }
1536
1537    #[test]
1538    fn test_convert_custom_message_to_cohere() {
1539        let context = Context::owned(
1540            None,
1541            vec![Message::Custom(crate::model::CustomMessage {
1542                custom_type: "extension_note".to_string(),
1543                content: "Important context.".to_string(),
1544                display: false,
1545                details: None,
1546                timestamp: 0,
1547            })],
1548            vec![],
1549        );
1550
1551        let msgs = build_cohere_messages(&context);
1552        assert_eq!(msgs.len(), 1);
1553        let value = serde_json::to_value(&msgs[0]).expect("serialize");
1554        assert_eq!(value["role"], "user");
1555        assert_eq!(value["content"], "Important context.");
1556    }
1557
1558    #[test]
1559    fn test_convert_user_blocks_extracts_text_only() {
1560        let content = UserContent::Blocks(vec![
1561            ContentBlock::Text(TextContent::new("part 1")),
1562            ContentBlock::Image(crate::model::ImageContent {
1563                data: "aGVsbG8=".to_string(),
1564                mime_type: "image/png".to_string(),
1565            }),
1566            ContentBlock::Text(TextContent::new("part 2")),
1567        ]);
1568
1569        let text = extract_text_user_content(&content);
1570        assert_eq!(text, "part 1[Image: image/png (8 bytes)]part 2");
1571    }
1572
1573    // ─── Provider builder tests ─────────────────────────────────────────
1574
1575    #[test]
1576    fn test_custom_provider_name() {
1577        let provider = CohereProvider::new("command-r").with_provider_name("my-proxy");
1578        assert_eq!(provider.name(), "my-proxy");
1579        assert_eq!(provider.api(), "cohere-chat");
1580    }
1581
1582    #[test]
1583    fn test_custom_base_url() {
1584        let provider =
1585            CohereProvider::new("command-r").with_base_url("https://proxy.example.com/v2/chat");
1586        assert_eq!(provider.base_url, "https://proxy.example.com/v2/chat");
1587    }
1588
1589    // ─── Stream event parsing tests ─────────────────────────────────────
1590
1591    #[test]
1592    fn test_stream_complete_finish_reason_maps_to_stop() {
1593        let events = vec![
1594            json!({ "type": "message-start", "id": "msg_1" }),
1595            json!({
1596                "type": "message-end",
1597                "delta": {
1598                    "finish_reason": "COMPLETE",
1599                    "usage": { "tokens": { "input_tokens": 5, "output_tokens": 10 } }
1600                }
1601            }),
1602        ];
1603
1604        let out = collect_events(&events);
1605        assert!(out.iter().any(|e| matches!(
1606            e,
1607            StreamEvent::Done {
1608                reason: StopReason::Stop,
1609                message,
1610                ..
1611            } if message.usage.input == 5 && message.usage.output == 10
1612        )));
1613    }
1614
1615    #[test]
1616    fn test_stream_error_finish_reason_maps_to_error() {
1617        let events = vec![
1618            json!({ "type": "message-start", "id": "msg_1" }),
1619            json!({
1620                "type": "message-end",
1621                "delta": {
1622                    "finish_reason": "ERROR",
1623                    "usage": { "tokens": { "input_tokens": 1, "output_tokens": 0 } }
1624                }
1625            }),
1626        ];
1627
1628        let out = collect_events(&events);
1629        assert!(out.iter().any(|e| matches!(
1630            e,
1631            StreamEvent::Done {
1632                reason: StopReason::Error,
1633                ..
1634            }
1635        )));
1636    }
1637
1638    #[test]
1639    fn test_stream_tool_call_with_streamed_arguments() {
1640        let events = vec![
1641            json!({ "type": "message-start", "id": "msg_1" }),
1642            json!({
1643                "type": "tool-call-start",
1644                "delta": {
1645                    "message": {
1646                        "tool_calls": {
1647                            "id": "call_42",
1648                            "type": "function",
1649                            "function": { "name": "bash", "arguments": "{\"co" }
1650                        }
1651                    }
1652                }
1653            }),
1654            json!({
1655                "type": "tool-call-delta",
1656                "delta": {
1657                    "message": {
1658                        "tool_calls": {
1659                            "function": { "arguments": "mmand\"" }
1660                        }
1661                    }
1662                }
1663            }),
1664            json!({
1665                "type": "tool-call-delta",
1666                "delta": {
1667                    "message": {
1668                        "tool_calls": {
1669                            "function": { "arguments": ": \"ls -la\"}" }
1670                        }
1671                    }
1672                }
1673            }),
1674            json!({ "type": "tool-call-end" }),
1675            json!({
1676                "type": "message-end",
1677                "delta": {
1678                    "finish_reason": "TOOL_CALL",
1679                    "usage": { "tokens": { "input_tokens": 10, "output_tokens": 20 } }
1680                }
1681            }),
1682        ];
1683
1684        let out = collect_events(&events);
1685
1686        // Should have ToolCallEnd with properly assembled arguments.
1687        let tool_end = out
1688            .iter()
1689            .find(|e| matches!(e, StreamEvent::ToolCallEnd { .. }));
1690        assert!(tool_end.is_some(), "Expected ToolCallEnd event");
1691        if let Some(StreamEvent::ToolCallEnd { tool_call, .. }) = tool_end {
1692            assert_eq!(tool_call.name, "bash");
1693            assert_eq!(tool_call.id, "call_42");
1694            assert_eq!(tool_call.arguments["command"], "ls -la");
1695        }
1696    }
1697
1698    #[test]
1699    fn test_stream_unknown_event_type_ignored() {
1700        let events = vec![
1701            json!({ "type": "message-start", "id": "msg_1" }),
1702            json!({ "type": "some-future-event", "data": "ignored" }),
1703            json!({
1704                "type": "content-start",
1705                "index": 0,
1706                "delta": { "message": { "content": { "type": "text", "text": "OK" } } }
1707            }),
1708            json!({ "type": "content-end", "index": 0 }),
1709            json!({
1710                "type": "message-end",
1711                "delta": {
1712                    "finish_reason": "COMPLETE",
1713                    "usage": { "tokens": { "input_tokens": 1, "output_tokens": 1 } }
1714                }
1715            }),
1716        ];
1717
1718        let out = collect_events(&events);
1719        // Should complete successfully despite unknown event.
1720        assert!(out.iter().any(|e| matches!(e, StreamEvent::Done { .. })));
1721        assert!(out.iter().any(|e| matches!(
1722            e,
1723            StreamEvent::TextStart {
1724                content_index: 0,
1725                ..
1726            }
1727        )));
1728    }
1729}
1730
1731// ============================================================================
1732// Fuzzing support
1733// ============================================================================
1734
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    //! Hooks that let fuzz targets push raw SSE payloads through the Cohere
    //! stream processor. Only compiled when the `fuzzing` feature is enabled.

    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Item type of the placeholder byte stream backing the processor.
    type ChunkResult = std::result::Result<Vec<u8>, std::io::Error>;

    /// An always-empty byte stream; payloads are supplied through
    /// [`Processor::process_event`] rather than read from this stream.
    type FuzzStream = Pin<Box<futures::stream::Empty<ChunkResult>>>;

    /// Opaque handle over the Cohere stream processor state, so callers do
    /// not need to name the internal state type.
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        /// Equivalent to [`Processor::new`].
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Builds a processor whose state is freshly constructed over an
        /// empty input stream, using fixed placeholder labels.
        pub fn new() -> Self {
            let source: FuzzStream = Box::pin(stream::empty());
            let state = StreamState::new(
                crate::sse::SseStream::new(source),
                "cohere-fuzz".into(),
                "cohere".into(),
                "cohere".into(),
            );
            Self(state)
        }

        /// Runs a single SSE data payload through the processor and hands
        /// back every `StreamEvent` queued as a result.
        ///
        /// # Errors
        ///
        /// Returns the processor's error unchanged if it rejects `data`; in
        /// that case the pending queue is not drained by this call.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            let state = &mut self.0;
            state.process_event(data)?;
            let emitted: Vec<StreamEvent> = state.pending_events.drain(..).collect();
            Ok(emitted)
        }
    }
}