Skip to main content

pi/providers/
cohere.rs

1//! Cohere Chat API provider implementation.
2//!
3//! This module implements the Provider trait for Cohere's `v2/chat` endpoint,
4//! supporting streaming output text/thinking and function tool calls.
5
6use crate::error::{Error, Result};
7use crate::http::client::Client;
8use crate::model::{
9    AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ThinkingContent,
10    ToolCall, Usage, UserContent,
11};
12use crate::models::CompatConfig;
13use crate::provider::{Context, Provider, StreamOptions, ToolDef};
14use crate::sse::SseStream;
15use async_trait::async_trait;
16use futures::StreamExt;
17use futures::stream::{self, Stream};
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, VecDeque};
20use std::pin::Pin;
21
22// ============================================================================
23// Constants
24// ============================================================================
25
26const COHERE_CHAT_API_URL: &str = "https://api.cohere.com/v2/chat";
27const DEFAULT_MAX_TOKENS: u32 = 4096;
28
29// ============================================================================
30// Cohere Provider
31// ============================================================================
32
33/// Cohere `v2/chat` streaming provider.
34pub struct CohereProvider {
35    client: Client,
36    model: String,
37    base_url: String,
38    provider: String,
39    compat: Option<CompatConfig>,
40}
41
42impl CohereProvider {
43    pub fn new(model: impl Into<String>) -> Self {
44        Self {
45            client: Client::new(),
46            model: model.into(),
47            base_url: COHERE_CHAT_API_URL.to_string(),
48            provider: "cohere".to_string(),
49            compat: None,
50        }
51    }
52
53    #[must_use]
54    pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
55        self.provider = provider.into();
56        self
57    }
58
59    #[must_use]
60    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
61        self.base_url = base_url.into();
62        self
63    }
64
65    #[must_use]
66    pub fn with_client(mut self, client: Client) -> Self {
67        self.client = client;
68        self
69    }
70
71    /// Attach provider-specific compatibility overrides.
72    #[must_use]
73    pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
74        self.compat = compat;
75        self
76    }
77
78    pub fn build_request(&self, context: &Context<'_>, options: &StreamOptions) -> CohereRequest {
79        let messages = build_cohere_messages(context);
80
81        let tools: Option<Vec<CohereTool>> = if context.tools.is_empty() {
82            None
83        } else {
84            Some(context.tools.iter().map(convert_tool_to_cohere).collect())
85        };
86
87        CohereRequest {
88            model: self.model.clone(),
89            messages,
90            max_tokens: options.max_tokens.or(Some(DEFAULT_MAX_TOKENS)),
91            temperature: options.temperature,
92            tools,
93            stream: true,
94        }
95    }
96}
97
98fn authorization_override(
99    options: &StreamOptions,
100    compat: Option<&CompatConfig>,
101) -> Option<String> {
102    super::first_non_empty_header_value_case_insensitive(&options.headers, &["authorization"])
103        .or_else(|| {
104            compat
105                .and_then(|compat| compat.custom_headers.as_ref())
106                .and_then(|headers| {
107                    super::first_non_empty_header_value_case_insensitive(
108                        headers,
109                        &["authorization"],
110                    )
111                })
112        })
113}
114
115#[async_trait]
116#[allow(clippy::too_many_lines)]
117impl Provider for CohereProvider {
118    fn name(&self) -> &str {
119        &self.provider
120    }
121
122    fn api(&self) -> &'static str {
123        "cohere-chat"
124    }
125
126    fn model_id(&self) -> &str {
127        &self.model
128    }
129
130    async fn stream(
131        &self,
132        context: &Context<'_>,
133        options: &StreamOptions,
134    ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
135        let authorization_override = authorization_override(options, self.compat.as_ref());
136
137        let auth_value = if authorization_override.is_some() {
138            None
139        } else {
140            Some(
141                options
142                    .api_key
143                    .clone()
144                    .or_else(|| std::env::var("COHERE_API_KEY").ok())
145                    .ok_or_else(|| Error::provider("cohere", "Missing API key for provider. Configure credentials with /login <provider> or set the provider's API key env var."))?,
146            )
147        };
148
149        let request_body = self.build_request(context, options);
150
151        // Content-Type set by .json() below
152        let mut request = self
153            .client
154            .post(&self.base_url)
155            .header("Accept", "text/event-stream");
156
157        if let Some(auth_value) = auth_value {
158            request = request.header("Authorization", format!("Bearer {auth_value}"));
159        }
160
161        // Apply provider-specific custom headers from compat config.
162        if let Some(compat) = &self.compat {
163            if let Some(custom_headers) = &compat.custom_headers {
164                request = super::apply_headers_ignoring_blank_auth_overrides(
165                    request,
166                    custom_headers,
167                    &["authorization"],
168                );
169            }
170        }
171
172        // Per-request headers from StreamOptions (highest priority).
173        request = super::apply_headers_ignoring_blank_auth_overrides(
174            request,
175            &options.headers,
176            &["authorization"],
177        );
178
179        let request = request.json(&request_body)?;
180
181        let response = Box::pin(request.send()).await?;
182        let status = response.status();
183        if !(200..300).contains(&status) {
184            let body = response
185                .text()
186                .await
187                .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
188            return Err(Error::provider(
189                "cohere",
190                format!("Cohere API error (HTTP {status}): {body}"),
191            ));
192        }
193
194        let content_type = response
195            .headers()
196            .iter()
197            .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
198            .map(|(_, value)| value.to_ascii_lowercase());
199        if !content_type
200            .as_deref()
201            .is_some_and(|value| value.contains("text/event-stream"))
202        {
203            let message = content_type.map_or_else(
204                || {
205                    format!(
206                        "Cohere API protocol error (HTTP {status}): missing Content-Type (expected text/event-stream)"
207                    )
208                },
209                |value| {
210                    format!(
211                        "Cohere API protocol error (HTTP {status}): unexpected Content-Type {value} (expected text/event-stream)"
212                    )
213                },
214            );
215            return Err(Error::api(message));
216        }
217
218        let event_source = SseStream::new(response.bytes_stream());
219
220        let model = self.model.clone();
221        let api = self.api().to_string();
222        let provider = self.name().to_string();
223
224        let stream = stream::unfold(
225            StreamState::new(event_source, model, api, provider),
226            |mut state| async move {
227                loop {
228                    if let Some(event) = state.pending_events.pop_front() {
229                        return Some((Ok(event), state));
230                    }
231
232                    if state.finished {
233                        return None;
234                    }
235
236                    match state.event_source.next().await {
237                        Some(Ok(msg)) => {
238                            state.transient_error_count = 0;
239                            if msg.data == "[DONE]" {
240                                state.finish();
241                                continue;
242                            }
243
244                            if let Err(e) = state.process_event(&msg.data) {
245                                state.finished = true;
246                                return Some((Err(e), state));
247                            }
248                        }
249                        Some(Err(e)) => {
250                            // WriteZero, WouldBlock, and TimedOut errors are treated as transient.
251                            // Skip them and keep reading the stream, but cap
252                            // consecutive occurrences to avoid infinite loops.
253                            const MAX_CONSECUTIVE_TRANSIENT_ERRORS: usize = 5;
254                            if e.kind() == std::io::ErrorKind::WriteZero
255                                || e.kind() == std::io::ErrorKind::WouldBlock
256                                || e.kind() == std::io::ErrorKind::TimedOut
257                            {
258                                state.transient_error_count += 1;
259                                if state.transient_error_count <= MAX_CONSECUTIVE_TRANSIENT_ERRORS {
260                                    tracing::warn!(
261                                        kind = ?e.kind(),
262                                        count = state.transient_error_count,
263                                        "Transient error in SSE stream, continuing"
264                                    );
265                                    continue;
266                                }
267                                tracing::warn!(
268                                    kind = ?e.kind(),
269                                    "Error persisted after {MAX_CONSECUTIVE_TRANSIENT_ERRORS} \
270                                     consecutive attempts, treating as fatal"
271                                );
272                            }
273                            state.finished = true;
274                            let err = Error::api(format!("SSE error: {e}"));
275                            return Some((Err(err), state));
276                        }
277                        None => {
278                            // Stream ended without message-end; surface a consistent error.
279                            return Some((
280                                Err(Error::api("Stream ended without Done event")),
281                                state,
282                            ));
283                        }
284                    }
285                }
286            },
287        );
288
289        Ok(Box::pin(stream))
290    }
291}
292
293// ============================================================================
294// Stream State
295// ============================================================================
296
/// Accumulates one streamed tool call between `tool-call-start` and
/// `tool-call-end`.
struct ToolCallAccum {
    /// Tool call id from the `tool-call-start` event.
    id: String,
    /// Name of the function being called.
    name: String,
    /// Raw argument text, concatenated from the start and delta events.
    arguments: String,
    /// Position of the matching `ContentBlock::ToolCall` in the partial message.
    content_index: usize,
}
303
304struct StreamState<S>
305where
306    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
307{
308    event_source: SseStream<S>,
309    partial: AssistantMessage,
310    pending_events: VecDeque<StreamEvent>,
311    started: bool,
312    finished: bool,
313    content_index_map: HashMap<u32, usize>,
314    active_tool_call: Option<ToolCallAccum>,
315    /// Consecutive WriteZero errors seen without a successful event in between.
316    transient_error_count: usize,
317}
318
319impl<S> StreamState<S>
320where
321    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
322{
323    fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
324        Self {
325            event_source,
326            partial: AssistantMessage {
327                content: Vec::new(),
328                api,
329                provider,
330                model,
331                usage: Usage::default(),
332                stop_reason: StopReason::Stop,
333                error_message: None,
334                timestamp: chrono::Utc::now().timestamp_millis(),
335            },
336            pending_events: VecDeque::new(),
337            started: false,
338            finished: false,
339            content_index_map: HashMap::new(),
340            active_tool_call: None,
341            transient_error_count: 0,
342        }
343    }
344
345    fn ensure_started(&mut self) {
346        if !self.started {
347            self.started = true;
348            self.pending_events.push_back(StreamEvent::Start {
349                partial: self.partial.clone(),
350            });
351        }
352    }
353
354    fn content_block_for(&mut self, idx: u32, kind: CohereContentKind) -> usize {
355        if let Some(existing) = self.content_index_map.get(&idx) {
356            return *existing;
357        }
358
359        let content_index = self.partial.content.len();
360        match kind {
361            CohereContentKind::Text => {
362                self.partial
363                    .content
364                    .push(ContentBlock::Text(TextContent::new("")));
365                self.pending_events
366                    .push_back(StreamEvent::TextStart { content_index });
367            }
368            CohereContentKind::Thinking => {
369                self.partial
370                    .content
371                    .push(ContentBlock::Thinking(ThinkingContent {
372                        thinking: String::new(),
373                        thinking_signature: None,
374                    }));
375                self.pending_events
376                    .push_back(StreamEvent::ThinkingStart { content_index });
377            }
378        }
379
380        self.content_index_map.insert(idx, content_index);
381        content_index
382    }
383
384    #[allow(clippy::too_many_lines)]
385    fn process_event(&mut self, data: &str) -> Result<()> {
386        let chunk: CohereStreamChunk = serde_json::from_str(data)
387            .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;
388
389        match chunk {
390            CohereStreamChunk::MessageStart { .. } => {
391                self.ensure_started();
392            }
393            CohereStreamChunk::ContentStart { index, delta } => {
394                self.ensure_started();
395                let (kind, initial) = delta.message.content.kind_and_text();
396                let content_index = self.content_block_for(index, kind);
397
398                if !initial.is_empty() {
399                    match kind {
400                        CohereContentKind::Text => {
401                            if let Some(ContentBlock::Text(t)) =
402                                self.partial.content.get_mut(content_index)
403                            {
404                                t.text.push_str(&initial);
405                            }
406                            self.pending_events.push_back(StreamEvent::TextDelta {
407                                content_index,
408                                delta: initial,
409                            });
410                        }
411                        CohereContentKind::Thinking => {
412                            if let Some(ContentBlock::Thinking(t)) =
413                                self.partial.content.get_mut(content_index)
414                            {
415                                t.thinking.push_str(&initial);
416                            }
417                            self.pending_events.push_back(StreamEvent::ThinkingDelta {
418                                content_index,
419                                delta: initial,
420                            });
421                        }
422                    }
423                }
424            }
425            CohereStreamChunk::ContentDelta { index, delta } => {
426                self.ensure_started();
427                let (kind, delta_text) = delta.message.content.kind_and_text();
428                let content_index = self.content_block_for(index, kind);
429
430                match kind {
431                    CohereContentKind::Text => {
432                        if let Some(ContentBlock::Text(t)) =
433                            self.partial.content.get_mut(content_index)
434                        {
435                            t.text.push_str(&delta_text);
436                        }
437                        self.pending_events.push_back(StreamEvent::TextDelta {
438                            content_index,
439                            delta: delta_text,
440                        });
441                    }
442                    CohereContentKind::Thinking => {
443                        if let Some(ContentBlock::Thinking(t)) =
444                            self.partial.content.get_mut(content_index)
445                        {
446                            t.thinking.push_str(&delta_text);
447                        }
448                        self.pending_events.push_back(StreamEvent::ThinkingDelta {
449                            content_index,
450                            delta: delta_text,
451                        });
452                    }
453                }
454            }
455            CohereStreamChunk::ContentEnd { index } => {
456                if let Some(content_index) = self.content_index_map.get(&index).copied() {
457                    match self.partial.content.get(content_index) {
458                        Some(ContentBlock::Text(t)) => {
459                            self.pending_events.push_back(StreamEvent::TextEnd {
460                                content_index,
461                                content: t.text.clone(),
462                            });
463                        }
464                        Some(ContentBlock::Thinking(t)) => {
465                            self.pending_events.push_back(StreamEvent::ThinkingEnd {
466                                content_index,
467                                content: t.thinking.clone(),
468                            });
469                        }
470                        _ => {}
471                    }
472                }
473            }
474            CohereStreamChunk::ToolCallStart { delta } => {
475                self.ensure_started();
476                let tc = delta.message.tool_calls;
477                let content_index = self.partial.content.len();
478                self.partial.content.push(ContentBlock::ToolCall(ToolCall {
479                    id: tc.id.clone(),
480                    name: tc.function.name.clone(),
481                    arguments: serde_json::Value::Null,
482                    thought_signature: None,
483                }));
484
485                self.active_tool_call = Some(ToolCallAccum {
486                    content_index,
487                    id: tc.id,
488                    name: tc.function.name,
489                    arguments: tc.function.arguments.clone(),
490                });
491
492                self.pending_events
493                    .push_back(StreamEvent::ToolCallStart { content_index });
494                if !tc.function.arguments.is_empty() {
495                    self.pending_events.push_back(StreamEvent::ToolCallDelta {
496                        content_index,
497                        delta: tc.function.arguments,
498                    });
499                }
500            }
501            CohereStreamChunk::ToolCallDelta { delta } => {
502                self.ensure_started();
503                if let Some(active) = self.active_tool_call.as_mut() {
504                    active
505                        .arguments
506                        .push_str(&delta.message.tool_calls.function.arguments);
507                    self.pending_events.push_back(StreamEvent::ToolCallDelta {
508                        content_index: active.content_index,
509                        delta: delta.message.tool_calls.function.arguments,
510                    });
511                }
512            }
513            CohereStreamChunk::ToolCallEnd => {
514                if let Some(active) = self.active_tool_call.take() {
515                    self.ensure_started();
516                    let parsed_args: serde_json::Value = serde_json::from_str(&active.arguments)
517                        .unwrap_or_else(|e| {
518                            tracing::warn!(
519                                error = %e,
520                                raw = %active.arguments,
521                                "Failed to parse tool arguments as JSON"
522                            );
523                            serde_json::Value::Null
524                        });
525
526                    self.partial.stop_reason = StopReason::ToolUse;
527                    self.pending_events.push_back(StreamEvent::ToolCallEnd {
528                        content_index: active.content_index,
529                        tool_call: ToolCall {
530                            id: active.id,
531                            name: active.name,
532                            arguments: parsed_args.clone(),
533                            thought_signature: None,
534                        },
535                    });
536
537                    if let Some(ContentBlock::ToolCall(block)) =
538                        self.partial.content.get_mut(active.content_index)
539                    {
540                        block.arguments = parsed_args;
541                    }
542                }
543            }
544            CohereStreamChunk::MessageEnd { delta } => {
545                self.ensure_started();
546                self.partial.usage.input = delta.usage.tokens.input_tokens;
547                self.partial.usage.output = delta.usage.tokens.output_tokens;
548                self.partial.usage.total_tokens =
549                    delta.usage.tokens.input_tokens + delta.usage.tokens.output_tokens;
550
551                self.partial.stop_reason = match delta.finish_reason.as_str() {
552                    "MAX_TOKENS" => StopReason::Length,
553                    "TOOL_CALL" => StopReason::ToolUse,
554                    "ERROR" => StopReason::Error,
555                    _ => StopReason::Stop,
556                };
557
558                self.finish();
559            }
560            CohereStreamChunk::Unknown => {}
561        }
562
563        Ok(())
564    }
565
566    fn finish(&mut self) {
567        if self.finished {
568            return;
569        }
570        let reason = self.partial.stop_reason;
571        self.pending_events.push_back(StreamEvent::Done {
572            reason,
573            message: std::mem::take(&mut self.partial),
574        });
575        self.finished = true;
576    }
577}
578
579// ============================================================================
580// Cohere API Types (minimal)
581// ============================================================================
582
/// Request body for Cohere's `v2/chat` endpoint.
///
/// Fields serialize in declaration order; `None` options are omitted.
#[derive(Debug, Serialize)]
pub struct CohereRequest {
    model: String,
    messages: Vec<CohereMessage>,
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<CohereTool>>,
    /// `build_request` sets this to `true`; the response is then expected to be
    /// `text/event-stream`.
    stream: bool,
}

/// One chat message in Cohere's v2 format.
///
/// Serialized with a `role` field holding the lowercased variant name
/// (`system`, `user`, `assistant`, `tool`).
#[derive(Debug, Serialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum CohereMessage {
    System {
        content: String,
    },
    User {
        content: String,
    },
    /// Assistant turn; `None` fields are omitted from the JSON.
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<CohereToolCallRef>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_plan: Option<String>,
    },
    /// Result of a tool invocation, linked to its call by `tool_call_id`.
    Tool {
        content: String,
        tool_call_id: String,
    },
}

/// A previously issued tool call, echoed back inside an assistant message.
#[derive(Debug, Serialize)]
struct CohereToolCallRef {
    id: String,
    #[serde(rename = "type")]
    r#type: &'static str,
    function: CohereFunctionRef,
}

/// Function name and arguments of an echoed tool call.
#[derive(Debug, Serialize)]
struct CohereFunctionRef {
    name: String,
    /// Arguments as a JSON-encoded string (not a nested object).
    arguments: String,
}

/// A tool definition offered to the model.
#[derive(Debug, Serialize)]
struct CohereTool {
    /// Serialized as `type`; `convert_tool_to_cohere` sets it to `"function"`.
    #[serde(rename = "type")]
    r#type: &'static str,
    function: CohereFunction,
}

/// Name, optional description, and parameter schema of a tool.
#[derive(Debug, Serialize)]
struct CohereFunction {
    name: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// Parameter schema, passed through unchanged from `ToolDef::parameters`.
    parameters: serde_json::Value,
}
647
648fn convert_tool_to_cohere(tool: &ToolDef) -> CohereTool {
649    CohereTool {
650        r#type: "function",
651        function: CohereFunction {
652            name: tool.name.clone(),
653            description: if tool.description.trim().is_empty() {
654                None
655            } else {
656                Some(tool.description.clone())
657            },
658            parameters: tool.parameters.clone(),
659        },
660    }
661}
662
663fn build_cohere_messages(context: &Context<'_>) -> Vec<CohereMessage> {
664    let mut out = Vec::new();
665
666    if let Some(system) = &context.system_prompt {
667        out.push(CohereMessage::System {
668            content: system.to_string(),
669        });
670    }
671
672    for message in context.messages.iter() {
673        match message {
674            Message::User(user) => out.push(CohereMessage::User {
675                content: extract_text_user_content(&user.content),
676            }),
677            Message::Custom(custom) => out.push(CohereMessage::User {
678                content: custom.content.clone(),
679            }),
680            Message::Assistant(assistant) => {
681                let mut text = String::new();
682                let mut tool_calls = Vec::new();
683
684                for block in &assistant.content {
685                    match block {
686                        ContentBlock::Text(t) => text.push_str(&t.text),
687                        ContentBlock::ToolCall(tc) => tool_calls.push(CohereToolCallRef {
688                            id: tc.id.clone(),
689                            r#type: "function",
690                            function: CohereFunctionRef {
691                                name: tc.name.clone(),
692                                arguments: tc.arguments.to_string(),
693                            },
694                        }),
695                        _ => {}
696                    }
697                }
698
699                out.push(CohereMessage::Assistant {
700                    content: if text.is_empty() { None } else { Some(text) },
701                    tool_calls: if tool_calls.is_empty() {
702                        None
703                    } else {
704                        Some(tool_calls)
705                    },
706                    tool_plan: None,
707                });
708            }
709            Message::ToolResult(result) => {
710                let mut content = String::new();
711                for (i, block) in result.content.iter().enumerate() {
712                    if i > 0 {
713                        content.push('\n');
714                    }
715                    if let ContentBlock::Text(t) = block {
716                        content.push_str(&t.text);
717                    }
718                }
719                out.push(CohereMessage::Tool {
720                    content,
721                    tool_call_id: result.tool_call_id.clone(),
722                });
723            }
724        }
725    }
726
727    out
728}
729
730fn extract_text_user_content(content: &UserContent) -> String {
731    match content {
732        UserContent::Text(text) => text.clone(),
733        UserContent::Blocks(blocks) => {
734            let mut out = String::new();
735            for block in blocks {
736                match block {
737                    ContentBlock::Text(t) => out.push_str(&t.text),
738                    ContentBlock::Image(img) => {
739                        use std::fmt::Write as _;
740                        let _ =
741                            write!(out, "[Image: {} ({} bytes)]", img.mime_type, img.data.len());
742                    }
743                    _ => {}
744                }
745            }
746            out
747        }
748    }
749}
750
751// ============================================================================
752// Cohere streaming chunk types (minimal, forward-compatible)
753// ============================================================================
754
755#[derive(Debug, Deserialize)]
756#[serde(tag = "type")]
757enum CohereStreamChunk {
758    #[serde(rename = "message-start")]
759    MessageStart { id: Option<String> },
760    #[serde(rename = "content-start")]
761    ContentStart {
762        index: u32,
763        delta: CohereContentStartDelta,
764    },
765    #[serde(rename = "content-delta")]
766    ContentDelta {
767        index: u32,
768        delta: CohereContentDelta,
769    },
770    #[serde(rename = "content-end")]
771    ContentEnd { index: u32 },
772    #[serde(rename = "tool-call-start")]
773    ToolCallStart { delta: CohereToolCallStartDelta },
774    #[serde(rename = "tool-call-delta")]
775    ToolCallDelta { delta: CohereToolCallDelta },
776    #[serde(rename = "tool-call-end")]
777    ToolCallEnd,
778    #[serde(rename = "message-end")]
779    MessageEnd { delta: CohereMessageEndDelta },
780    #[serde(other)]
781    Unknown,
782}
783
784#[derive(Debug, Deserialize)]
785struct CohereContentStartDelta {
786    message: CohereDeltaMessage<CohereContentStart>,
787}
788
789#[derive(Debug, Deserialize)]
790struct CohereContentDelta {
791    message: CohereDeltaMessage<CohereContentDeltaPart>,
792}
793
794#[derive(Debug, Deserialize)]
795struct CohereDeltaMessage<T> {
796    content: T,
797}
798
799#[derive(Debug, Deserialize)]
800#[serde(tag = "type")]
801enum CohereContentStart {
802    #[serde(rename = "text")]
803    Text { text: String },
804    #[serde(rename = "thinking")]
805    Thinking { thinking: String },
806}
807
808#[derive(Debug, Deserialize)]
809#[serde(untagged)]
810enum CohereContentDeltaPart {
811    Text { text: String },
812    Thinking { thinking: String },
813}
814
815#[derive(Debug, Clone, Copy)]
816enum CohereContentKind {
817    Text,
818    Thinking,
819}
820
821impl CohereContentStart {
822    fn kind_and_text(self) -> (CohereContentKind, String) {
823        match self {
824            Self::Text { text } => (CohereContentKind::Text, text),
825            Self::Thinking { thinking } => (CohereContentKind::Thinking, thinking),
826        }
827    }
828}
829
830impl CohereContentDeltaPart {
831    fn kind_and_text(self) -> (CohereContentKind, String) {
832        match self {
833            Self::Text { text } => (CohereContentKind::Text, text),
834            Self::Thinking { thinking } => (CohereContentKind::Thinking, thinking),
835        }
836    }
837}
838
/// `delta` payload of a `tool-call-start` event.
#[derive(Debug, Deserialize)]
struct CohereToolCallStartDelta {
    message: CohereToolCallMessage<CohereToolCallStartBody>,
}

/// `delta` payload of a `tool-call-delta` event.
#[derive(Debug, Deserialize)]
struct CohereToolCallDelta {
    message: CohereToolCallMessage<CohereToolCallDeltaBody>,
}

/// `delta.message` wrapper for tool-call events.
///
/// Here `tool_calls` is deserialized as a single call object (`T`), not as an
/// array.
#[derive(Debug, Deserialize)]
struct CohereToolCallMessage<T> {
    tool_calls: T,
}

/// Tool call announced by `tool-call-start`.
#[derive(Debug, Deserialize)]
struct CohereToolCallStartBody {
    id: String,
    function: CohereToolCallFunctionStart,
}

/// Function name plus the first fragment of the argument text.
#[derive(Debug, Deserialize)]
struct CohereToolCallFunctionStart {
    name: String,
    /// Initial argument text; may be empty.
    arguments: String,
}

/// Tool call fragment carried by `tool-call-delta`.
#[derive(Debug, Deserialize)]
struct CohereToolCallDeltaBody {
    function: CohereToolCallFunctionDelta,
}

/// Next argument fragment, appended to the fragments received so far.
#[derive(Debug, Deserialize)]
struct CohereToolCallFunctionDelta {
    arguments: String,
}

/// `delta` payload of the final `message-end` event.
#[derive(Debug, Deserialize)]
struct CohereMessageEndDelta {
    /// Raw finish reason string; the stream state maps `MAX_TOKENS`,
    /// `TOOL_CALL` and `ERROR` to specific stop reasons and anything else to a
    /// normal stop.
    finish_reason: String,
    usage: CohereUsage,
}

/// Usage section of `message-end`; only `tokens` is read.
#[derive(Debug, Deserialize)]
struct CohereUsage {
    tokens: CohereUsageTokens,
}

/// Input and output token counts reported by Cohere.
#[derive(Debug, Deserialize)]
struct CohereUsageTokens {
    input_tokens: u64,
    output_tokens: u64,
}
892
893// ============================================================================
894// Tests
895// ============================================================================
896
897#[cfg(test)]
898mod tests {
899    use super::*;
900    use asupersync::runtime::RuntimeBuilder;
901    use futures::stream;
902    use serde_json::{Value, json};
903    use std::collections::HashMap;
904    use std::io::{Read, Write};
905    use std::net::TcpListener;
906    use std::path::PathBuf;
907    use std::sync::mpsc;
908    use std::time::Duration;
909
910    // ─── Fixture infrastructure ─────────────────────────────────────────
911
912    #[derive(Debug, Deserialize)]
913    struct ProviderFixture {
914        cases: Vec<ProviderCase>,
915    }
916
917    #[derive(Debug, Deserialize)]
918    struct ProviderCase {
919        name: String,
920        events: Vec<Value>,
921        expected: Vec<EventSummary>,
922    }
923
924    #[derive(Debug, Deserialize, Serialize, PartialEq)]
925    struct EventSummary {
926        kind: String,
927        #[serde(default)]
928        content_index: Option<usize>,
929        #[serde(default)]
930        delta: Option<String>,
931        #[serde(default)]
932        content: Option<String>,
933        #[serde(default)]
934        reason: Option<String>,
935    }
936
937    fn load_fixture(file_name: &str) -> ProviderFixture {
938        let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
939            .join("tests/fixtures/provider_responses")
940            .join(file_name);
941        let raw = std::fs::read_to_string(path).expect("fixture read");
942        serde_json::from_str(&raw).expect("fixture parse")
943    }
944
945    fn summarize_event(event: &StreamEvent) -> EventSummary {
946        match event {
947            StreamEvent::Start { .. } => EventSummary {
948                kind: "start".to_string(),
949                content_index: None,
950                delta: None,
951                content: None,
952                reason: None,
953            },
954            StreamEvent::TextStart { content_index, .. } => EventSummary {
955                kind: "text_start".to_string(),
956                content_index: Some(*content_index),
957                delta: None,
958                content: None,
959                reason: None,
960            },
961            StreamEvent::TextDelta {
962                content_index,
963                delta,
964                ..
965            } => EventSummary {
966                kind: "text_delta".to_string(),
967                content_index: Some(*content_index),
968                delta: Some(delta.clone()),
969                content: None,
970                reason: None,
971            },
972            StreamEvent::TextEnd {
973                content_index,
974                content,
975                ..
976            } => EventSummary {
977                kind: "text_end".to_string(),
978                content_index: Some(*content_index),
979                delta: None,
980                content: Some(content.clone()),
981                reason: None,
982            },
983            StreamEvent::Done { reason, .. } => EventSummary {
984                kind: "done".to_string(),
985                content_index: None,
986                delta: None,
987                content: None,
988                reason: Some(reason_to_string(*reason)),
989            },
990            StreamEvent::Error { reason, .. } => EventSummary {
991                kind: "error".to_string(),
992                content_index: None,
993                delta: None,
994                content: None,
995                reason: Some(reason_to_string(*reason)),
996            },
997            _ => EventSummary {
998                kind: "other".to_string(),
999                content_index: None,
1000                delta: None,
1001                content: None,
1002                reason: None,
1003            },
1004        }
1005    }
1006
1007    fn reason_to_string(reason: StopReason) -> String {
1008        match reason {
1009            StopReason::Stop => "stop",
1010            StopReason::Length => "length",
1011            StopReason::ToolUse => "tool_use",
1012            StopReason::Error => "error",
1013            StopReason::Aborted => "aborted",
1014        }
1015        .to_string()
1016    }
1017
1018    #[test]
1019    fn test_stream_fixtures() {
1020        let fixture = load_fixture("cohere_stream.json");
1021        for case in fixture.cases {
1022            let events = collect_events(&case.events);
1023            let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
1024            assert_eq!(summaries, case.expected, "case {}", case.name);
1025        }
1026    }
1027
1028    // ─── Existing tests ─────────────────────────────────────────────────
1029
1030    #[test]
1031    fn test_provider_info() {
1032        let provider = CohereProvider::new("command-r");
1033        assert_eq!(provider.name(), "cohere");
1034        assert_eq!(provider.api(), "cohere-chat");
1035    }
1036
1037    #[test]
1038    fn test_build_request_includes_system_tools_and_v2_shape() {
1039        let provider = CohereProvider::new("command-r");
1040        let context = Context::owned(
1041            Some("You are concise.".to_string()),
1042            vec![Message::User(crate::model::UserMessage {
1043                content: UserContent::Text("Ping".to_string()),
1044                timestamp: 0,
1045            })],
1046            vec![ToolDef {
1047                name: "search".to_string(),
1048                description: "Search docs".to_string(),
1049                parameters: json!({
1050                    "type": "object",
1051                    "properties": {
1052                        "q": { "type": "string" }
1053                    },
1054                    "required": ["q"]
1055                }),
1056            }],
1057        );
1058        let options = StreamOptions {
1059            temperature: Some(0.2),
1060            max_tokens: Some(123),
1061            ..Default::default()
1062        };
1063
1064        let request = provider.build_request(&context, &options);
1065        let value = serde_json::to_value(&request).expect("serialize request");
1066
1067        assert_eq!(value["model"], "command-r");
1068        assert_eq!(value["messages"][0]["role"], "system");
1069        assert_eq!(value["messages"][0]["content"], "You are concise.");
1070        assert_eq!(value["messages"][1]["role"], "user");
1071        assert_eq!(value["messages"][1]["content"], "Ping");
1072        assert_eq!(value["stream"], true);
1073        assert_eq!(value["max_tokens"], 123);
1074        let temperature = value["temperature"]
1075            .as_f64()
1076            .expect("temperature should be numeric");
1077        assert!((temperature - 0.2).abs() < 1e-6);
1078        assert_eq!(value["tools"][0]["type"], "function");
1079        assert_eq!(value["tools"][0]["function"]["name"], "search");
1080        assert_eq!(value["tools"][0]["function"]["description"], "Search docs");
1081        assert_eq!(
1082            value["tools"][0]["function"]["parameters"],
1083            json!({
1084                "type": "object",
1085                "properties": {
1086                    "q": { "type": "string" }
1087                },
1088                "required": ["q"]
1089            })
1090        );
1091    }
1092
1093    #[test]
1094    fn test_convert_tool_to_cohere_omits_empty_description() {
1095        let tool = ToolDef {
1096            name: "echo".to_string(),
1097            description: "   ".to_string(),
1098            parameters: json!({
1099                "type": "object",
1100                "properties": {
1101                    "text": { "type": "string" }
1102                }
1103            }),
1104        };
1105
1106        let converted = convert_tool_to_cohere(&tool);
1107        let value = serde_json::to_value(converted).expect("serialize converted tool");
1108        assert_eq!(value["type"], "function");
1109        assert_eq!(value["function"]["name"], "echo");
1110        assert!(value["function"].get("description").is_none());
1111    }
1112
1113    #[test]
1114    fn test_stream_parses_text_and_tool_call() {
1115        let runtime = RuntimeBuilder::current_thread()
1116            .build()
1117            .expect("runtime build");
1118
1119        runtime.block_on(async move {
1120            let events = [
1121                serde_json::json!({ "type": "message-start", "id": "msg_1" }),
1122                serde_json::json!({
1123                    "type": "content-start",
1124                    "index": 0,
1125                    "delta": { "message": { "content": { "type": "text", "text": "Hello" } } }
1126                }),
1127                serde_json::json!({
1128                    "type": "content-delta",
1129                    "index": 0,
1130                    "delta": { "message": { "content": { "text": " world" } } }
1131                }),
1132                serde_json::json!({ "type": "content-end", "index": 0 }),
1133                serde_json::json!({
1134                    "type": "tool-call-start",
1135                    "delta": { "message": { "tool_calls": { "id": "call_1", "type": "function", "function": { "name": "echo", "arguments": "{\"text\":\"hi\"}" } } } }
1136                }),
1137                serde_json::json!({ "type": "tool-call-end" }),
1138                serde_json::json!({
1139                    "type": "message-end",
1140                    "delta": { "finish_reason": "TOOL_CALL", "usage": { "tokens": { "input_tokens": 1, "output_tokens": 2 } } }
1141                }),
1142            ];
1143
1144            let byte_stream = stream::iter(
1145                events
1146                    .iter()
1147                    .map(|event| format!("data: {}\n\n", serde_json::to_string(event).unwrap()))
1148                    .map(|s| Ok(s.into_bytes())),
1149            );
1150
1151            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1152            let mut state = StreamState::new(
1153                event_source,
1154                "command-r".to_string(),
1155                "cohere-chat".to_string(),
1156                "cohere".to_string(),
1157            );
1158
1159            let mut out = Vec::new();
1160            while let Some(item) = state.event_source.next().await {
1161                let msg = item.expect("SSE event");
1162                state.process_event(&msg.data).expect("process_event");
1163                out.extend(state.pending_events.drain(..));
1164                if state.finished {
1165                    break;
1166                }
1167            }
1168
1169            assert!(matches!(out.first(), Some(StreamEvent::Start { .. })));
1170            assert!(out.iter().any(|e| matches!(e, StreamEvent::TextDelta { delta, .. } if delta.contains("Hello"))));
1171            assert!(out.iter().any(|e| matches!(e, StreamEvent::ToolCallEnd { tool_call, .. } if tool_call.name == "echo")));
1172            assert!(out.iter().any(|e| matches!(e, StreamEvent::Done { reason: StopReason::ToolUse, .. })));
1173        });
1174    }
1175
1176    #[test]
1177    fn test_stream_parses_thinking_and_max_tokens_stop_reason() {
1178        let events = vec![
1179            json!({ "type": "message-start", "id": "msg_1" }),
1180            json!({
1181                "type": "content-start",
1182                "index": 0,
1183                "delta": { "message": { "content": { "type": "thinking", "thinking": "Plan" } } }
1184            }),
1185            json!({
1186                "type": "content-delta",
1187                "index": 0,
1188                "delta": { "message": { "content": { "thinking": " more" } } }
1189            }),
1190            json!({ "type": "content-end", "index": 0 }),
1191            json!({
1192                "type": "message-end",
1193                "delta": { "finish_reason": "MAX_TOKENS", "usage": { "tokens": { "input_tokens": 2, "output_tokens": 3 } } }
1194            }),
1195        ];
1196
1197        let out = collect_events(&events);
1198        assert!(
1199            out.iter()
1200                .any(|e| matches!(e, StreamEvent::ThinkingStart { .. }))
1201        );
1202        assert!(out.iter().any(
1203            |e| matches!(e, StreamEvent::ThinkingDelta { delta, .. } if delta.contains("Plan"))
1204        ));
1205        assert!(
1206            out.iter()
1207                .any(|e| matches!(e, StreamEvent::ThinkingEnd { content, .. } if content.contains("Plan more")))
1208        );
1209        assert!(out.iter().any(|e| matches!(
1210            e,
1211            StreamEvent::Done {
1212                reason: StopReason::Length,
1213                ..
1214            }
1215        )));
1216    }
1217
1218    #[test]
1219    fn test_stream_sets_bearer_auth_header() {
1220        let captured = run_stream_and_capture_headers(Some("test-cohere-key"), HashMap::new())
1221            .expect("captured request");
1222        assert_eq!(
1223            captured.headers.get("authorization").map(String::as_str),
1224            Some("Bearer test-cohere-key")
1225        );
1226        assert_eq!(
1227            captured.headers.get("accept").map(String::as_str),
1228            Some("text/event-stream")
1229        );
1230
1231        let body: Value = serde_json::from_str(&captured.body).expect("body json");
1232        assert_eq!(body["model"], "command-r");
1233        assert_eq!(body["stream"], true);
1234    }
1235
1236    #[test]
1237    fn test_stream_uses_existing_authorization_header_without_api_key() {
1238        let mut headers = HashMap::new();
1239        headers.insert(
1240            "Authorization".to_string(),
1241            "Bearer from-custom-header".to_string(),
1242        );
1243        headers.insert("X-Test".to_string(), "1".to_string());
1244
1245        let captured = run_stream_and_capture_headers(None, headers).expect("captured request");
1246        assert_eq!(
1247            captured.headers.get("authorization").map(String::as_str),
1248            Some("Bearer from-custom-header")
1249        );
1250        assert_eq!(
1251            captured.headers.get("x-test").map(String::as_str),
1252            Some("1")
1253        );
1254    }
1255
1256    #[test]
1257    fn test_stream_compat_authorization_header_overrides_api_key_without_duplicate() {
1258        let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1259        let mut custom_headers = HashMap::new();
1260        custom_headers.insert(
1261            "Authorization".to_string(),
1262            "Bearer compat-header".to_string(),
1263        );
1264        let provider = CohereProvider::new("command-r")
1265            .with_base_url(base_url)
1266            .with_compat(Some(CompatConfig {
1267                custom_headers: Some(custom_headers),
1268                ..Default::default()
1269            }));
1270        let context = Context::owned(
1271            Some("system".to_string()),
1272            vec![Message::User(crate::model::UserMessage {
1273                content: UserContent::Text("ping".to_string()),
1274                timestamp: 0,
1275            })],
1276            Vec::new(),
1277        );
1278        let options = StreamOptions {
1279            api_key: Some("test-cohere-key".to_string()),
1280            ..Default::default()
1281        };
1282
1283        let runtime = RuntimeBuilder::current_thread()
1284            .build()
1285            .expect("runtime build");
1286        runtime.block_on(async {
1287            let mut stream = provider.stream(&context, &options).await.expect("stream");
1288            while let Some(event) = stream.next().await {
1289                if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1290                    break;
1291                }
1292            }
1293        });
1294
1295        let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
1296        assert_eq!(
1297            captured.headers.get("authorization").map(String::as_str),
1298            Some("Bearer compat-header")
1299        );
1300        assert_eq!(captured.header_count("authorization"), 1);
1301    }
1302
1303    #[test]
1304    fn test_stream_compat_authorization_header_works_without_api_key() {
1305        let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1306        let mut custom_headers = HashMap::new();
1307        custom_headers.insert(
1308            "Authorization".to_string(),
1309            "Bearer compat-header".to_string(),
1310        );
1311        let provider = CohereProvider::new("command-r")
1312            .with_base_url(base_url)
1313            .with_compat(Some(CompatConfig {
1314                custom_headers: Some(custom_headers),
1315                ..Default::default()
1316            }));
1317        let context = Context::owned(
1318            Some("system".to_string()),
1319            vec![Message::User(crate::model::UserMessage {
1320                content: UserContent::Text("ping".to_string()),
1321                timestamp: 0,
1322            })],
1323            Vec::new(),
1324        );
1325
1326        let runtime = RuntimeBuilder::current_thread()
1327            .build()
1328            .expect("runtime build");
1329        runtime.block_on(async {
1330            let mut stream = provider
1331                .stream(&context, &StreamOptions::default())
1332                .await
1333                .expect("stream");
1334            while let Some(event) = stream.next().await {
1335                if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1336                    break;
1337                }
1338            }
1339        });
1340
1341        let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
1342        assert_eq!(
1343            captured.headers.get("authorization").map(String::as_str),
1344            Some("Bearer compat-header")
1345        );
1346        assert_eq!(captured.header_count("authorization"), 1);
1347    }
1348
1349    fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
1350        let runtime = RuntimeBuilder::current_thread()
1351            .build()
1352            .expect("runtime build");
1353
1354        runtime.block_on(async {
1355            let byte_stream = stream::iter(
1356                events
1357                    .iter()
1358                    .map(|event| {
1359                        format!(
1360                            "data: {}\n\n",
1361                            serde_json::to_string(event).expect("serialize event")
1362                        )
1363                    })
1364                    .map(|s| Ok(s.into_bytes())),
1365            );
1366
1367            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1368            let mut state = StreamState::new(
1369                event_source,
1370                "command-r".to_string(),
1371                "cohere-chat".to_string(),
1372                "cohere".to_string(),
1373            );
1374
1375            let mut out = Vec::new();
1376            while let Some(item) = state.event_source.next().await {
1377                let msg = item.expect("SSE event");
1378                state.process_event(&msg.data).expect("process_event");
1379                out.extend(state.pending_events.drain(..));
1380                if state.finished {
1381                    break;
1382                }
1383            }
1384            out
1385        })
1386    }
1387
1388    #[derive(Debug)]
1389    struct CapturedRequest {
1390        headers: HashMap<String, String>,
1391        header_lines: Vec<(String, String)>,
1392        body: String,
1393    }
1394
1395    impl CapturedRequest {
1396        fn header_count(&self, name: &str) -> usize {
1397            self.header_lines
1398                .iter()
1399                .filter(|(key, _)| key.eq_ignore_ascii_case(name))
1400                .count()
1401        }
1402    }
1403
1404    fn run_stream_and_capture_headers(
1405        api_key: Option<&str>,
1406        extra_headers: HashMap<String, String>,
1407    ) -> Option<CapturedRequest> {
1408        let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1409        let provider = CohereProvider::new("command-r").with_base_url(base_url);
1410        let context = Context::owned(
1411            Some("system".to_string()),
1412            vec![Message::User(crate::model::UserMessage {
1413                content: UserContent::Text("ping".to_string()),
1414                timestamp: 0,
1415            })],
1416            Vec::new(),
1417        );
1418        let options = StreamOptions {
1419            api_key: api_key.map(str::to_string),
1420            headers: extra_headers,
1421            ..Default::default()
1422        };
1423
1424        let runtime = RuntimeBuilder::current_thread()
1425            .build()
1426            .expect("runtime build");
1427        runtime.block_on(async {
1428            let mut stream = provider.stream(&context, &options).await.expect("stream");
1429            while let Some(event) = stream.next().await {
1430                if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1431                    break;
1432                }
1433            }
1434        });
1435
1436        rx.recv_timeout(Duration::from_secs(2)).ok()
1437    }
1438
1439    fn success_sse_body() -> String {
1440        [
1441            r#"data: {"type":"message-start","id":"msg_1"}"#,
1442            "",
1443            r#"data: {"type":"message-end","delta":{"finish_reason":"COMPLETE","usage":{"tokens":{"input_tokens":1,"output_tokens":1}}}}"#,
1444            "",
1445        ]
1446        .join("\n")
1447    }
1448
1449    fn spawn_test_server(
1450        status_code: u16,
1451        content_type: &str,
1452        body: &str,
1453    ) -> (String, mpsc::Receiver<CapturedRequest>) {
1454        let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
1455        let addr = listener.local_addr().expect("local addr");
1456        let (tx, rx) = mpsc::channel();
1457        let body = body.to_string();
1458        let content_type = content_type.to_string();
1459
1460        std::thread::spawn(move || {
1461            let (mut socket, _) = listener.accept().expect("accept");
1462            socket
1463                .set_read_timeout(Some(Duration::from_secs(2)))
1464                .expect("set read timeout");
1465
1466            let mut bytes = Vec::new();
1467            let mut chunk = [0_u8; 4096];
1468            loop {
1469                match socket.read(&mut chunk) {
1470                    Ok(0) => break,
1471                    Ok(n) => {
1472                        bytes.extend_from_slice(&chunk[..n]);
1473                        if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
1474                            break;
1475                        }
1476                    }
1477                    Err(err)
1478                        if err.kind() == std::io::ErrorKind::WouldBlock
1479                            || err.kind() == std::io::ErrorKind::TimedOut =>
1480                    {
1481                        break;
1482                    }
1483                    Err(err) => panic!(),
1484                }
1485            }
1486
1487            let header_end = bytes
1488                .windows(4)
1489                .position(|window| window == b"\r\n\r\n")
1490                .expect("request header boundary");
1491            let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
1492            let (headers, header_lines) = parse_headers(&header_text);
1493            let mut request_body = bytes[header_end + 4..].to_vec();
1494
1495            let content_length = headers
1496                .get("content-length")
1497                .and_then(|value| value.parse::<usize>().ok())
1498                .unwrap_or(0);
1499            while request_body.len() < content_length {
1500                match socket.read(&mut chunk) {
1501                    Ok(0) => break,
1502                    Ok(n) => request_body.extend_from_slice(&chunk[..n]),
1503                    Err(err)
1504                        if err.kind() == std::io::ErrorKind::WouldBlock
1505                            || err.kind() == std::io::ErrorKind::TimedOut =>
1506                    {
1507                        break;
1508                    }
1509                    Err(err) => panic!(),
1510                }
1511            }
1512
1513            let captured = CapturedRequest {
1514                headers,
1515                header_lines,
1516                body: String::from_utf8_lossy(&request_body).to_string(),
1517            };
1518            tx.send(captured).expect("send captured request");
1519
1520            let reason = match status_code {
1521                401 => "Unauthorized",
1522                500 => "Internal Server Error",
1523                _ => "OK",
1524            };
1525            let response = format!(
1526                "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
1527                body.len()
1528            );
1529            socket
1530                .write_all(response.as_bytes())
1531                .expect("write response");
1532            socket.flush().expect("flush response");
1533        });
1534
1535        (format!("http://{addr}"), rx)
1536    }
1537
1538    fn parse_headers(header_text: &str) -> (HashMap<String, String>, Vec<(String, String)>) {
1539        let mut headers = HashMap::new();
1540        let mut header_lines = Vec::new();
1541        for line in header_text.lines().skip(1) {
1542            if let Some((name, value)) = line.split_once(':') {
1543                let normalized_name = name.trim().to_ascii_lowercase();
1544                let normalized_value = value.trim().to_string();
1545                header_lines.push((normalized_name.clone(), normalized_value.clone()));
1546                headers.insert(normalized_name, normalized_value);
1547            }
1548        }
1549        (headers, header_lines)
1550    }
1551
1552    // ─── Request body format tests ──────────────────────────────────────
1553
1554    #[test]
1555    fn test_build_request_no_system_prompt() {
1556        let provider = CohereProvider::new("command-r-plus");
1557        let context = Context::owned(
1558            None,
1559            vec![Message::User(crate::model::UserMessage {
1560                content: UserContent::Text("Hi".to_string()),
1561                timestamp: 0,
1562            })],
1563            vec![],
1564        );
1565        let options = StreamOptions::default();
1566
1567        let req = provider.build_request(&context, &options);
1568        let value = serde_json::to_value(&req).expect("serialize");
1569
1570        // First message should be user, no system message.
1571        assert_eq!(value["messages"][0]["role"], "user");
1572        assert_eq!(value["messages"][0]["content"], "Hi");
1573        // No system role message at all.
1574        let msgs = value["messages"].as_array().unwrap();
1575        assert!(
1576            !msgs.iter().any(|m| m["role"] == "system"),
1577            "No system message should be present"
1578        );
1579    }
1580
1581    #[test]
1582    fn test_build_request_default_max_tokens() {
1583        let provider = CohereProvider::new("command-r");
1584        let context = Context::owned(
1585            None,
1586            vec![Message::User(crate::model::UserMessage {
1587                content: UserContent::Text("test".to_string()),
1588                timestamp: 0,
1589            })],
1590            vec![],
1591        );
1592        let options = StreamOptions::default();
1593
1594        let req = provider.build_request(&context, &options);
1595        let value = serde_json::to_value(&req).expect("serialize");
1596
1597        assert_eq!(value["max_tokens"], DEFAULT_MAX_TOKENS);
1598    }
1599
1600    #[test]
1601    fn test_build_request_no_tools_omits_tools_field() {
1602        let provider = CohereProvider::new("command-r");
1603        let context = Context::owned(
1604            None,
1605            vec![Message::User(crate::model::UserMessage {
1606                content: UserContent::Text("test".to_string()),
1607                timestamp: 0,
1608            })],
1609            vec![],
1610        );
1611        let options = StreamOptions::default();
1612
1613        let req = provider.build_request(&context, &options);
1614        let value = serde_json::to_value(&req).expect("serialize");
1615
1616        assert!(
1617            value.get("tools").is_none() || value["tools"].is_null(),
1618            "tools field should be omitted when empty"
1619        );
1620    }
1621
1622    #[test]
1623    fn test_build_request_full_conversation_with_tool_call_and_result() {
1624        let provider = CohereProvider::new("command-r");
1625        let context = Context::owned(
1626            Some("Be concise.".to_string()),
1627            vec![
1628                Message::User(crate::model::UserMessage {
1629                    content: UserContent::Text("Read /tmp/a.txt".to_string()),
1630                    timestamp: 0,
1631                }),
1632                Message::assistant(AssistantMessage {
1633                    content: vec![ContentBlock::ToolCall(ToolCall {
1634                        id: "call_1".to_string(),
1635                        name: "read".to_string(),
1636                        arguments: serde_json::json!({"path": "/tmp/a.txt"}),
1637                        thought_signature: None,
1638                    })],
1639                    api: "cohere-chat".to_string(),
1640                    provider: "cohere".to_string(),
1641                    model: "command-r".to_string(),
1642                    usage: Usage::default(),
1643                    stop_reason: StopReason::ToolUse,
1644                    error_message: None,
1645                    timestamp: 1,
1646                }),
1647                Message::tool_result(crate::model::ToolResultMessage {
1648                    tool_call_id: "call_1".to_string(),
1649                    tool_name: "read".to_string(),
1650                    content: vec![ContentBlock::Text(TextContent::new("file contents"))],
1651                    details: None,
1652                    is_error: false,
1653                    timestamp: 2,
1654                }),
1655            ],
1656            vec![ToolDef {
1657                name: "read".to_string(),
1658                description: "Read a file".to_string(),
1659                parameters: json!({"type": "object"}),
1660            }],
1661        );
1662        let options = StreamOptions::default();
1663
1664        let req = provider.build_request(&context, &options);
1665        let value = serde_json::to_value(&req).expect("serialize");
1666
1667        let msgs = value["messages"].as_array().unwrap();
1668        // system, user, assistant, tool
1669        assert_eq!(msgs.len(), 4);
1670        assert_eq!(msgs[0]["role"], "system");
1671        assert_eq!(msgs[1]["role"], "user");
1672        assert_eq!(msgs[2]["role"], "assistant");
1673        assert_eq!(msgs[3]["role"], "tool");
1674
1675        // Assistant message should have tool_calls, not content text.
1676        assert!(msgs[2].get("content").is_none() || msgs[2]["content"].is_null());
1677        let tool_calls = msgs[2]["tool_calls"].as_array().unwrap();
1678        assert_eq!(tool_calls.len(), 1);
1679        assert_eq!(tool_calls[0]["id"], "call_1");
1680        assert_eq!(tool_calls[0]["type"], "function");
1681        assert_eq!(tool_calls[0]["function"]["name"], "read");
1682
1683        // Tool result should reference the tool_call_id.
1684        assert_eq!(msgs[3]["tool_call_id"], "call_1");
1685        assert_eq!(msgs[3]["content"], "file contents");
1686    }
1687
1688    #[test]
1689    fn test_build_request_assistant_text_preserved_alongside_tool_calls() {
1690        let provider = CohereProvider::new("command-r");
1691        let context = Context::owned(
1692            None,
1693            vec![Message::assistant(AssistantMessage {
1694                content: vec![
1695                    ContentBlock::Text(TextContent::new("Let me read that file.")),
1696                    ContentBlock::ToolCall(ToolCall {
1697                        id: "call_1".to_string(),
1698                        name: "read".to_string(),
1699                        arguments: json!({"path": "/tmp/a.txt"}),
1700                        thought_signature: None,
1701                    }),
1702                ],
1703                api: "cohere-chat".to_string(),
1704                provider: "cohere".to_string(),
1705                model: "command-r".to_string(),
1706                usage: Usage::default(),
1707                stop_reason: StopReason::ToolUse,
1708                error_message: None,
1709                timestamp: 0,
1710            })],
1711            vec![],
1712        );
1713        let options = StreamOptions::default();
1714
1715        let req = provider.build_request(&context, &options);
1716        let value = serde_json::to_value(&req).expect("serialize");
1717        let msgs = value["messages"].as_array().unwrap();
1718
1719        assert_eq!(msgs[0]["role"], "assistant");
1720        assert_eq!(
1721            msgs[0]["content"].as_str(),
1722            Some("Let me read that file."),
1723            "Assistant text must be preserved when tool_calls are also present"
1724        );
1725        assert_eq!(msgs[0]["tool_calls"].as_array().unwrap().len(), 1);
1726    }
1727
1728    #[test]
1729    fn test_convert_custom_message_to_cohere() {
1730        let context = Context::owned(
1731            None,
1732            vec![Message::Custom(crate::model::CustomMessage {
1733                custom_type: "extension_note".to_string(),
1734                content: "Important context.".to_string(),
1735                display: false,
1736                details: None,
1737                timestamp: 0,
1738            })],
1739            vec![],
1740        );
1741
1742        let msgs = build_cohere_messages(&context);
1743        assert_eq!(msgs.len(), 1);
1744        let value = serde_json::to_value(&msgs[0]).expect("serialize");
1745        assert_eq!(value["role"], "user");
1746        assert_eq!(value["content"], "Important context.");
1747    }
1748
1749    #[test]
1750    fn test_convert_user_blocks_extracts_text_only() {
1751        let content = UserContent::Blocks(vec![
1752            ContentBlock::Text(TextContent::new("part 1")),
1753            ContentBlock::Image(crate::model::ImageContent {
1754                data: "aGVsbG8=".to_string(),
1755                mime_type: "image/png".to_string(),
1756            }),
1757            ContentBlock::Text(TextContent::new("part 2")),
1758        ]);
1759
1760        let text = extract_text_user_content(&content);
1761        assert_eq!(text, "part 1[Image: image/png (8 bytes)]part 2");
1762    }
1763
1764    // ─── Provider builder tests ─────────────────────────────────────────
1765
1766    #[test]
1767    fn test_custom_provider_name() {
1768        let provider = CohereProvider::new("command-r").with_provider_name("my-proxy");
1769        assert_eq!(provider.name(), "my-proxy");
1770        assert_eq!(provider.api(), "cohere-chat");
1771    }
1772
1773    #[test]
1774    fn test_custom_base_url() {
1775        let provider =
1776            CohereProvider::new("command-r").with_base_url("https://proxy.example.com/v2/chat");
1777        assert_eq!(provider.base_url, "https://proxy.example.com/v2/chat");
1778    }
1779
1780    // ─── Stream event parsing tests ─────────────────────────────────────
1781
1782    #[test]
1783    fn test_stream_complete_finish_reason_maps_to_stop() {
1784        let events = vec![
1785            json!({ "type": "message-start", "id": "msg_1" }),
1786            json!({
1787                "type": "message-end",
1788                "delta": {
1789                    "finish_reason": "COMPLETE",
1790                    "usage": { "tokens": { "input_tokens": 5, "output_tokens": 10 } }
1791                }
1792            }),
1793        ];
1794
1795        let out = collect_events(&events);
1796        assert!(out.iter().any(|e| matches!(
1797            e,
1798            StreamEvent::Done {
1799                reason: StopReason::Stop,
1800                message,
1801                ..
1802            } if message.usage.input == 5 && message.usage.output == 10
1803        )));
1804    }
1805
1806    #[test]
1807    fn test_stream_error_finish_reason_maps_to_error() {
1808        let events = vec![
1809            json!({ "type": "message-start", "id": "msg_1" }),
1810            json!({
1811                "type": "message-end",
1812                "delta": {
1813                    "finish_reason": "ERROR",
1814                    "usage": { "tokens": { "input_tokens": 1, "output_tokens": 0 } }
1815                }
1816            }),
1817        ];
1818
1819        let out = collect_events(&events);
1820        assert!(out.iter().any(|e| matches!(
1821            e,
1822            StreamEvent::Done {
1823                reason: StopReason::Error,
1824                ..
1825            }
1826        )));
1827    }
1828
1829    #[test]
1830    fn test_stream_tool_call_with_streamed_arguments() {
1831        let events = vec![
1832            json!({ "type": "message-start", "id": "msg_1" }),
1833            json!({
1834                "type": "tool-call-start",
1835                "delta": {
1836                    "message": {
1837                        "tool_calls": {
1838                            "id": "call_42",
1839                            "type": "function",
1840                            "function": { "name": "bash", "arguments": "{\"co" }
1841                        }
1842                    }
1843                }
1844            }),
1845            json!({
1846                "type": "tool-call-delta",
1847                "delta": {
1848                    "message": {
1849                        "tool_calls": {
1850                            "function": { "arguments": "mmand\"" }
1851                        }
1852                    }
1853                }
1854            }),
1855            json!({
1856                "type": "tool-call-delta",
1857                "delta": {
1858                    "message": {
1859                        "tool_calls": {
1860                            "function": { "arguments": ": \"ls -la\"}" }
1861                        }
1862                    }
1863                }
1864            }),
1865            json!({ "type": "tool-call-end" }),
1866            json!({
1867                "type": "message-end",
1868                "delta": {
1869                    "finish_reason": "TOOL_CALL",
1870                    "usage": { "tokens": { "input_tokens": 10, "output_tokens": 20 } }
1871                }
1872            }),
1873        ];
1874
1875        let out = collect_events(&events);
1876
1877        // Should have ToolCallEnd with properly assembled arguments.
1878        let tool_end = out
1879            .iter()
1880            .find(|e| matches!(e, StreamEvent::ToolCallEnd { .. }));
1881        assert!(tool_end.is_some(), "Expected ToolCallEnd event");
1882        if let Some(StreamEvent::ToolCallEnd { tool_call, .. }) = tool_end {
1883            assert_eq!(tool_call.name, "bash");
1884            assert_eq!(tool_call.id, "call_42");
1885            assert_eq!(tool_call.arguments["command"], "ls -la");
1886        }
1887    }
1888
1889    #[test]
1890    fn test_stream_unknown_event_type_ignored() {
1891        let events = vec![
1892            json!({ "type": "message-start", "id": "msg_1" }),
1893            json!({ "type": "some-future-event", "data": "ignored" }),
1894            json!({
1895                "type": "content-start",
1896                "index": 0,
1897                "delta": { "message": { "content": { "type": "text", "text": "OK" } } }
1898            }),
1899            json!({ "type": "content-end", "index": 0 }),
1900            json!({
1901                "type": "message-end",
1902                "delta": {
1903                    "finish_reason": "COMPLETE",
1904                    "usage": { "tokens": { "input_tokens": 1, "output_tokens": 1 } }
1905                }
1906            }),
1907        ];
1908
1909        let out = collect_events(&events);
1910        // Should complete successfully despite unknown event.
1911        assert!(out.iter().any(|e| matches!(e, StreamEvent::Done { .. })));
1912        assert!(out.iter().any(|e| matches!(
1913            e,
1914            StreamEvent::TextStart {
1915                content_index: 0,
1916                ..
1917            }
1918        )));
1919    }
1920}
1921
1922// ============================================================================
1923// Fuzzing support
1924// ============================================================================
1925
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    use super::*;

    /// Item type of the byte stream that normally feeds the SSE parser.
    type ChunkResult = std::result::Result<Vec<u8>, std::io::Error>;

    /// Transport for the fuzz processor. It is always empty because fuzz inputs
    /// are pushed in directly through [`Processor::process_event`].
    type FuzzStream = std::pin::Pin<Box<futures::stream::Empty<ChunkResult>>>;

    /// Opaque wrapper around the Cohere stream processor state.
    pub struct Processor(StreamState<FuzzStream>);

    impl Processor {
        /// Build a processor over an empty transport, with fixed fuzz model and
        /// provider labels.
        pub fn new() -> Self {
            let transport: FuzzStream = Box::pin(futures::stream::empty());
            let state = StreamState::new(
                crate::sse::SseStream::new(transport),
                "cohere-fuzz".into(),
                "cohere".into(),
                "cohere".into(),
            );
            Self(state)
        }

        /// Run a single SSE `data` payload through the processor.
        ///
        /// Returns every `StreamEvent` queued while handling it, leaving the pending
        /// queue empty for the next call.
        ///
        /// # Errors
        /// Propagates any error raised by the underlying event processor.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            let Self(state) = self;
            state.process_event(data)?;
            Ok(state.pending_events.drain(..).collect())
        }
    }

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }
}