1use crate::error::{Error, Result};
7use crate::http::client::Client;
8use crate::model::{
9 AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ThinkingContent,
10 ToolCall, Usage, UserContent,
11};
12use crate::models::CompatConfig;
13use crate::provider::{Context, Provider, StreamOptions, ToolDef};
14use crate::sse::SseStream;
15use async_trait::async_trait;
16use futures::StreamExt;
17use futures::stream::{self, Stream};
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, VecDeque};
20use std::pin::Pin;
21
/// Default chat endpoint; used as the provider's initial `base_url`.
const COHERE_CHAT_API_URL: &str = "https://api.cohere.com/v2/chat";
/// `max_tokens` value sent when the caller's `StreamOptions` leaves it unset.
const DEFAULT_MAX_TOKENS: u32 = 4096;
28
/// Streaming chat provider that POSTs requests to a Cohere-style chat endpoint.
pub struct CohereProvider {
    /// HTTP client used to send the chat request.
    client: Client,
    /// Model identifier placed in the request body and returned by `model_id()`.
    model: String,
    /// URL the chat request is POSTed to.
    base_url: String,
    /// Provider name returned by `name()`; `"cohere"` unless overridden.
    provider: String,
    /// Optional compatibility settings; the visible code only reads `custom_headers`.
    compat: Option<CompatConfig>,
}
41
42impl CohereProvider {
43 pub fn new(model: impl Into<String>) -> Self {
44 Self {
45 client: Client::new(),
46 model: model.into(),
47 base_url: COHERE_CHAT_API_URL.to_string(),
48 provider: "cohere".to_string(),
49 compat: None,
50 }
51 }
52
53 #[must_use]
54 pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
55 self.provider = provider.into();
56 self
57 }
58
59 #[must_use]
60 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
61 self.base_url = base_url.into();
62 self
63 }
64
65 #[must_use]
66 pub fn with_client(mut self, client: Client) -> Self {
67 self.client = client;
68 self
69 }
70
71 #[must_use]
73 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
74 self.compat = compat;
75 self
76 }
77
78 pub fn build_request(&self, context: &Context<'_>, options: &StreamOptions) -> CohereRequest {
79 let messages = build_cohere_messages(context);
80
81 let tools: Option<Vec<CohereTool>> = if context.tools.is_empty() {
82 None
83 } else {
84 Some(context.tools.iter().map(convert_tool_to_cohere).collect())
85 };
86
87 CohereRequest {
88 model: self.model.clone(),
89 messages,
90 max_tokens: options.max_tokens.or(Some(DEFAULT_MAX_TOKENS)),
91 temperature: options.temperature,
92 tools,
93 stream: true,
94 }
95 }
96}
97
98fn authorization_override(
99 options: &StreamOptions,
100 compat: Option<&CompatConfig>,
101) -> Option<String> {
102 super::first_non_empty_header_value_case_insensitive(&options.headers, &["authorization"])
103 .or_else(|| {
104 compat
105 .and_then(|compat| compat.custom_headers.as_ref())
106 .and_then(|headers| {
107 super::first_non_empty_header_value_case_insensitive(
108 headers,
109 &["authorization"],
110 )
111 })
112 })
113}
114
115#[async_trait]
116#[allow(clippy::too_many_lines)]
117impl Provider for CohereProvider {
118 fn name(&self) -> &str {
119 &self.provider
120 }
121
    /// Returns the API identifier for this provider, always `"cohere-chat"`.
    fn api(&self) -> &'static str {
        "cohere-chat"
    }
125
126 fn model_id(&self) -> &str {
127 &self.model
128 }
129
130 async fn stream(
131 &self,
132 context: &Context<'_>,
133 options: &StreamOptions,
134 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
135 let authorization_override = authorization_override(options, self.compat.as_ref());
136
137 let auth_value = if authorization_override.is_some() {
138 None
139 } else {
140 Some(
141 options
142 .api_key
143 .clone()
144 .or_else(|| std::env::var("COHERE_API_KEY").ok())
145 .ok_or_else(|| Error::provider("cohere", "Missing API key for provider. Configure credentials with /login <provider> or set the provider's API key env var."))?,
146 )
147 };
148
149 let request_body = self.build_request(context, options);
150
151 let mut request = self
153 .client
154 .post(&self.base_url)
155 .header("Accept", "text/event-stream");
156
157 if let Some(auth_value) = auth_value {
158 request = request.header("Authorization", format!("Bearer {auth_value}"));
159 }
160
161 if let Some(compat) = &self.compat {
163 if let Some(custom_headers) = &compat.custom_headers {
164 request = super::apply_headers_ignoring_blank_auth_overrides(
165 request,
166 custom_headers,
167 &["authorization"],
168 );
169 }
170 }
171
172 request = super::apply_headers_ignoring_blank_auth_overrides(
174 request,
175 &options.headers,
176 &["authorization"],
177 );
178
179 let request = request.json(&request_body)?;
180
181 let response = Box::pin(request.send()).await?;
182 let status = response.status();
183 if !(200..300).contains(&status) {
184 let body = response
185 .text()
186 .await
187 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
188 return Err(Error::provider(
189 "cohere",
190 format!("Cohere API error (HTTP {status}): {body}"),
191 ));
192 }
193
194 let content_type = response
195 .headers()
196 .iter()
197 .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
198 .map(|(_, value)| value.to_ascii_lowercase());
199 if !content_type
200 .as_deref()
201 .is_some_and(|value| value.contains("text/event-stream"))
202 {
203 let message = content_type.map_or_else(
204 || {
205 format!(
206 "Cohere API protocol error (HTTP {status}): missing Content-Type (expected text/event-stream)"
207 )
208 },
209 |value| {
210 format!(
211 "Cohere API protocol error (HTTP {status}): unexpected Content-Type {value} (expected text/event-stream)"
212 )
213 },
214 );
215 return Err(Error::api(message));
216 }
217
218 let event_source = SseStream::new(response.bytes_stream());
219
220 let model = self.model.clone();
221 let api = self.api().to_string();
222 let provider = self.name().to_string();
223
224 let stream = stream::unfold(
225 StreamState::new(event_source, model, api, provider),
226 |mut state| async move {
227 loop {
228 if let Some(event) = state.pending_events.pop_front() {
229 return Some((Ok(event), state));
230 }
231
232 if state.finished {
233 return None;
234 }
235
236 match state.event_source.next().await {
237 Some(Ok(msg)) => {
238 state.transient_error_count = 0;
239 if msg.data == "[DONE]" {
240 state.finish();
241 continue;
242 }
243
244 if let Err(e) = state.process_event(&msg.data) {
245 state.finished = true;
246 return Some((Err(e), state));
247 }
248 }
249 Some(Err(e)) => {
250 const MAX_CONSECUTIVE_TRANSIENT_ERRORS: usize = 5;
254 if e.kind() == std::io::ErrorKind::WriteZero
255 || e.kind() == std::io::ErrorKind::WouldBlock
256 || e.kind() == std::io::ErrorKind::TimedOut
257 {
258 state.transient_error_count += 1;
259 if state.transient_error_count <= MAX_CONSECUTIVE_TRANSIENT_ERRORS {
260 tracing::warn!(
261 kind = ?e.kind(),
262 count = state.transient_error_count,
263 "Transient error in SSE stream, continuing"
264 );
265 continue;
266 }
267 tracing::warn!(
268 kind = ?e.kind(),
269 "Error persisted after {MAX_CONSECUTIVE_TRANSIENT_ERRORS} \
270 consecutive attempts, treating as fatal"
271 );
272 }
273 state.finished = true;
274 let err = Error::api(format!("SSE error: {e}"));
275 return Some((Err(err), state));
276 }
277 None => {
278 return Some((
280 Err(Error::api("Stream ended without Done event")),
281 state,
282 ));
283 }
284 }
285 }
286 },
287 );
288
289 Ok(Box::pin(stream))
290 }
291}
292
/// Accumulates one in-flight tool call between its start and end events.
struct ToolCallAccum {
    /// Index of the tool call's block within the partial message's content.
    content_index: usize,
    /// Tool call id taken from the start event.
    id: String,
    /// Function name taken from the start event.
    name: String,
    /// Raw argument text concatenated from the start and delta events;
    /// parsed as JSON when the tool call ends.
    arguments: String,
}
303
/// Mutable state threaded through the response stream: the SSE source, the
/// message assembled so far, and the events waiting to be yielded.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Source of SSE messages.
    event_source: SseStream<S>,
    /// Assistant message built incrementally from the chunks seen so far.
    partial: AssistantMessage,
    /// Events produced but not yet handed to the consumer.
    pending_events: VecDeque<StreamEvent>,
    /// Whether the `Start` event has already been queued.
    started: bool,
    /// Whether the stream has terminated (`Done` queued or a fatal error returned).
    finished: bool,
    /// Maps the provider's content index to the index in `partial.content`.
    content_index_map: HashMap<u32, usize>,
    /// Tool call currently being accumulated, if any.
    active_tool_call: Option<ToolCallAccum>,
    /// Number of consecutive transient read errors; reset on each successful message.
    transient_error_count: usize,
}
318
319impl<S> StreamState<S>
320where
321 S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
322{
323 fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
324 Self {
325 event_source,
326 partial: AssistantMessage {
327 content: Vec::new(),
328 api,
329 provider,
330 model,
331 usage: Usage::default(),
332 stop_reason: StopReason::Stop,
333 error_message: None,
334 timestamp: chrono::Utc::now().timestamp_millis(),
335 },
336 pending_events: VecDeque::new(),
337 started: false,
338 finished: false,
339 content_index_map: HashMap::new(),
340 active_tool_call: None,
341 transient_error_count: 0,
342 }
343 }
344
345 fn ensure_started(&mut self) {
346 if !self.started {
347 self.started = true;
348 self.pending_events.push_back(StreamEvent::Start {
349 partial: self.partial.clone(),
350 });
351 }
352 }
353
354 fn content_block_for(&mut self, idx: u32, kind: CohereContentKind) -> usize {
355 if let Some(existing) = self.content_index_map.get(&idx) {
356 return *existing;
357 }
358
359 let content_index = self.partial.content.len();
360 match kind {
361 CohereContentKind::Text => {
362 self.partial
363 .content
364 .push(ContentBlock::Text(TextContent::new("")));
365 self.pending_events
366 .push_back(StreamEvent::TextStart { content_index });
367 }
368 CohereContentKind::Thinking => {
369 self.partial
370 .content
371 .push(ContentBlock::Thinking(ThinkingContent {
372 thinking: String::new(),
373 thinking_signature: None,
374 }));
375 self.pending_events
376 .push_back(StreamEvent::ThinkingStart { content_index });
377 }
378 }
379
380 self.content_index_map.insert(idx, content_index);
381 content_index
382 }
383
384 #[allow(clippy::too_many_lines)]
385 fn process_event(&mut self, data: &str) -> Result<()> {
386 let chunk: CohereStreamChunk = serde_json::from_str(data)
387 .map_err(|e| Error::api(format!("JSON parse error: {e}\nData: {data}")))?;
388
389 match chunk {
390 CohereStreamChunk::MessageStart { .. } => {
391 self.ensure_started();
392 }
393 CohereStreamChunk::ContentStart { index, delta } => {
394 self.ensure_started();
395 let (kind, initial) = delta.message.content.kind_and_text();
396 let content_index = self.content_block_for(index, kind);
397
398 if !initial.is_empty() {
399 match kind {
400 CohereContentKind::Text => {
401 if let Some(ContentBlock::Text(t)) =
402 self.partial.content.get_mut(content_index)
403 {
404 t.text.push_str(&initial);
405 }
406 self.pending_events.push_back(StreamEvent::TextDelta {
407 content_index,
408 delta: initial,
409 });
410 }
411 CohereContentKind::Thinking => {
412 if let Some(ContentBlock::Thinking(t)) =
413 self.partial.content.get_mut(content_index)
414 {
415 t.thinking.push_str(&initial);
416 }
417 self.pending_events.push_back(StreamEvent::ThinkingDelta {
418 content_index,
419 delta: initial,
420 });
421 }
422 }
423 }
424 }
425 CohereStreamChunk::ContentDelta { index, delta } => {
426 self.ensure_started();
427 let (kind, delta_text) = delta.message.content.kind_and_text();
428 let content_index = self.content_block_for(index, kind);
429
430 match kind {
431 CohereContentKind::Text => {
432 if let Some(ContentBlock::Text(t)) =
433 self.partial.content.get_mut(content_index)
434 {
435 t.text.push_str(&delta_text);
436 }
437 self.pending_events.push_back(StreamEvent::TextDelta {
438 content_index,
439 delta: delta_text,
440 });
441 }
442 CohereContentKind::Thinking => {
443 if let Some(ContentBlock::Thinking(t)) =
444 self.partial.content.get_mut(content_index)
445 {
446 t.thinking.push_str(&delta_text);
447 }
448 self.pending_events.push_back(StreamEvent::ThinkingDelta {
449 content_index,
450 delta: delta_text,
451 });
452 }
453 }
454 }
455 CohereStreamChunk::ContentEnd { index } => {
456 if let Some(content_index) = self.content_index_map.get(&index).copied() {
457 match self.partial.content.get(content_index) {
458 Some(ContentBlock::Text(t)) => {
459 self.pending_events.push_back(StreamEvent::TextEnd {
460 content_index,
461 content: t.text.clone(),
462 });
463 }
464 Some(ContentBlock::Thinking(t)) => {
465 self.pending_events.push_back(StreamEvent::ThinkingEnd {
466 content_index,
467 content: t.thinking.clone(),
468 });
469 }
470 _ => {}
471 }
472 }
473 }
474 CohereStreamChunk::ToolCallStart { delta } => {
475 self.ensure_started();
476 let tc = delta.message.tool_calls;
477 let content_index = self.partial.content.len();
478 self.partial.content.push(ContentBlock::ToolCall(ToolCall {
479 id: tc.id.clone(),
480 name: tc.function.name.clone(),
481 arguments: serde_json::Value::Null,
482 thought_signature: None,
483 }));
484
485 self.active_tool_call = Some(ToolCallAccum {
486 content_index,
487 id: tc.id,
488 name: tc.function.name,
489 arguments: tc.function.arguments.clone(),
490 });
491
492 self.pending_events
493 .push_back(StreamEvent::ToolCallStart { content_index });
494 if !tc.function.arguments.is_empty() {
495 self.pending_events.push_back(StreamEvent::ToolCallDelta {
496 content_index,
497 delta: tc.function.arguments,
498 });
499 }
500 }
501 CohereStreamChunk::ToolCallDelta { delta } => {
502 self.ensure_started();
503 if let Some(active) = self.active_tool_call.as_mut() {
504 active
505 .arguments
506 .push_str(&delta.message.tool_calls.function.arguments);
507 self.pending_events.push_back(StreamEvent::ToolCallDelta {
508 content_index: active.content_index,
509 delta: delta.message.tool_calls.function.arguments,
510 });
511 }
512 }
513 CohereStreamChunk::ToolCallEnd => {
514 if let Some(active) = self.active_tool_call.take() {
515 self.ensure_started();
516 let parsed_args: serde_json::Value = serde_json::from_str(&active.arguments)
517 .unwrap_or_else(|e| {
518 tracing::warn!(
519 error = %e,
520 raw = %active.arguments,
521 "Failed to parse tool arguments as JSON"
522 );
523 serde_json::Value::Null
524 });
525
526 self.partial.stop_reason = StopReason::ToolUse;
527 self.pending_events.push_back(StreamEvent::ToolCallEnd {
528 content_index: active.content_index,
529 tool_call: ToolCall {
530 id: active.id,
531 name: active.name,
532 arguments: parsed_args.clone(),
533 thought_signature: None,
534 },
535 });
536
537 if let Some(ContentBlock::ToolCall(block)) =
538 self.partial.content.get_mut(active.content_index)
539 {
540 block.arguments = parsed_args;
541 }
542 }
543 }
544 CohereStreamChunk::MessageEnd { delta } => {
545 self.ensure_started();
546 self.partial.usage.input = delta.usage.tokens.input_tokens;
547 self.partial.usage.output = delta.usage.tokens.output_tokens;
548 self.partial.usage.total_tokens =
549 delta.usage.tokens.input_tokens + delta.usage.tokens.output_tokens;
550
551 self.partial.stop_reason = match delta.finish_reason.as_str() {
552 "MAX_TOKENS" => StopReason::Length,
553 "TOOL_CALL" => StopReason::ToolUse,
554 "ERROR" => StopReason::Error,
555 _ => StopReason::Stop,
556 };
557
558 self.finish();
559 }
560 CohereStreamChunk::Unknown => {}
561 }
562
563 Ok(())
564 }
565
566 fn finish(&mut self) {
567 if self.finished {
568 return;
569 }
570 let reason = self.partial.stop_reason;
571 self.pending_events.push_back(StreamEvent::Done {
572 reason,
573 message: std::mem::take(&mut self.partial),
574 });
575 self.finished = true;
576 }
577}
578
/// JSON body sent to the chat endpoint. Optional fields are omitted from the
/// serialized output when `None`.
#[derive(Debug, Serialize)]
pub struct CohereRequest {
    /// Model identifier.
    model: String,
    /// Conversation messages in request order.
    messages: Vec<CohereMessage>,
    /// Upper bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Sampling temperature, when the caller provided one.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Tool definitions offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<CohereTool>>,
    /// Whether to request a streamed response; `build_request` always sets this to `true`.
    stream: bool,
}
595
/// One message in the request, serialized with a lowercase `role` tag
/// (`system`, `user`, `assistant`, `tool`).
#[derive(Debug, Serialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum CohereMessage {
    /// System prompt.
    System {
        content: String,
    },
    /// User-authored text.
    User {
        content: String,
    },
    /// Prior assistant turn; each field is omitted when `None`.
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<CohereToolCallRef>>,
        // `build_cohere_messages` in this file always sets this to `None`.
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_plan: Option<String>,
    },
    /// Result of a tool call, linked to the call by `tool_call_id`.
    Tool {
        content: String,
        tool_call_id: String,
    },
}
618
/// A tool call recorded on a prior assistant message.
#[derive(Debug, Serialize)]
struct CohereToolCallRef {
    /// Identifier of the tool call.
    id: String,
    /// Serialized as `type`; `build_cohere_messages` sets it to `"function"`.
    #[serde(rename = "type")]
    r#type: &'static str,
    /// Function name and arguments of the call.
    function: CohereFunctionRef,
}
626
/// Function name and arguments of a recorded tool call.
#[derive(Debug, Serialize)]
struct CohereFunctionRef {
    /// Name of the called function.
    name: String,
    /// Arguments as a string; `build_cohere_messages` fills this with the
    /// JSON text of the stored arguments value.
    arguments: String,
}
632
/// A tool definition offered to the model in the request.
#[derive(Debug, Serialize)]
struct CohereTool {
    /// Serialized as `type`; `convert_tool_to_cohere` sets it to `"function"`.
    #[serde(rename = "type")]
    r#type: &'static str,
    /// The function being described.
    function: CohereFunction,
}
639
/// Name, optional description, and parameter schema of an offered tool.
#[derive(Debug, Serialize)]
struct CohereFunction {
    /// Tool name.
    name: String,
    /// Tool description; omitted from the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// Parameter schema, copied verbatim from the tool definition.
    parameters: serde_json::Value,
}
647
648fn convert_tool_to_cohere(tool: &ToolDef) -> CohereTool {
649 CohereTool {
650 r#type: "function",
651 function: CohereFunction {
652 name: tool.name.clone(),
653 description: if tool.description.trim().is_empty() {
654 None
655 } else {
656 Some(tool.description.clone())
657 },
658 parameters: tool.parameters.clone(),
659 },
660 }
661}
662
663fn build_cohere_messages(context: &Context<'_>) -> Vec<CohereMessage> {
664 let mut out = Vec::new();
665
666 if let Some(system) = &context.system_prompt {
667 out.push(CohereMessage::System {
668 content: system.to_string(),
669 });
670 }
671
672 for message in context.messages.iter() {
673 match message {
674 Message::User(user) => out.push(CohereMessage::User {
675 content: extract_text_user_content(&user.content),
676 }),
677 Message::Custom(custom) => out.push(CohereMessage::User {
678 content: custom.content.clone(),
679 }),
680 Message::Assistant(assistant) => {
681 let mut text = String::new();
682 let mut tool_calls = Vec::new();
683
684 for block in &assistant.content {
685 match block {
686 ContentBlock::Text(t) => text.push_str(&t.text),
687 ContentBlock::ToolCall(tc) => tool_calls.push(CohereToolCallRef {
688 id: tc.id.clone(),
689 r#type: "function",
690 function: CohereFunctionRef {
691 name: tc.name.clone(),
692 arguments: tc.arguments.to_string(),
693 },
694 }),
695 _ => {}
696 }
697 }
698
699 out.push(CohereMessage::Assistant {
700 content: if text.is_empty() { None } else { Some(text) },
701 tool_calls: if tool_calls.is_empty() {
702 None
703 } else {
704 Some(tool_calls)
705 },
706 tool_plan: None,
707 });
708 }
709 Message::ToolResult(result) => {
710 let mut content = String::new();
711 for (i, block) in result.content.iter().enumerate() {
712 if i > 0 {
713 content.push('\n');
714 }
715 if let ContentBlock::Text(t) = block {
716 content.push_str(&t.text);
717 }
718 }
719 out.push(CohereMessage::Tool {
720 content,
721 tool_call_id: result.tool_call_id.clone(),
722 });
723 }
724 }
725 }
726
727 out
728}
729
730fn extract_text_user_content(content: &UserContent) -> String {
731 match content {
732 UserContent::Text(text) => text.clone(),
733 UserContent::Blocks(blocks) => {
734 let mut out = String::new();
735 for block in blocks {
736 match block {
737 ContentBlock::Text(t) => out.push_str(&t.text),
738 ContentBlock::Image(img) => {
739 use std::fmt::Write as _;
740 let _ =
741 write!(out, "[Image: {} ({} bytes)]", img.mime_type, img.data.len());
742 }
743 _ => {}
744 }
745 }
746 out
747 }
748 }
749}
750
/// One decoded SSE payload, selected by its `type` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum CohereStreamChunk {
    /// Start of the response message.
    #[serde(rename = "message-start")]
    MessageStart { id: Option<String> },
    /// Start of a content block at `index`; may carry initial text.
    #[serde(rename = "content-start")]
    ContentStart {
        index: u32,
        delta: CohereContentStartDelta,
    },
    /// Additional text for the content block at `index`.
    #[serde(rename = "content-delta")]
    ContentDelta {
        index: u32,
        delta: CohereContentDelta,
    },
    /// End of the content block at `index`.
    #[serde(rename = "content-end")]
    ContentEnd { index: u32 },
    /// Start of a tool call, with id, function name and initial arguments.
    #[serde(rename = "tool-call-start")]
    ToolCallStart { delta: CohereToolCallStartDelta },
    /// Additional argument text for the active tool call.
    #[serde(rename = "tool-call-delta")]
    ToolCallDelta { delta: CohereToolCallDelta },
    /// End of the active tool call.
    #[serde(rename = "tool-call-end")]
    ToolCallEnd,
    /// End of the message, with finish reason and token usage.
    #[serde(rename = "message-end")]
    MessageEnd { delta: CohereMessageEndDelta },
    /// Any other `type`; `process_event` ignores these.
    #[serde(other)]
    Unknown,
}
783
/// `delta` payload of a `content-start` chunk.
#[derive(Debug, Deserialize)]
struct CohereContentStartDelta {
    /// Wrapper holding the starting content.
    message: CohereDeltaMessage<CohereContentStart>,
}
788
/// `delta` payload of a `content-delta` chunk.
#[derive(Debug, Deserialize)]
struct CohereContentDelta {
    /// Wrapper holding the incremental content.
    message: CohereDeltaMessage<CohereContentDeltaPart>,
}
793
/// `message` object inside a content delta, generic over the content payload.
#[derive(Debug, Deserialize)]
struct CohereDeltaMessage<T> {
    /// The content payload.
    content: T,
}
798
/// Content carried by a `content-start` chunk, selected by its `type` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum CohereContentStart {
    /// Ordinary response text.
    #[serde(rename = "text")]
    Text { text: String },
    /// Thinking text.
    #[serde(rename = "thinking")]
    Thinking { thinking: String },
}
807
/// Content carried by a `content-delta` chunk. Untagged: the variant is
/// chosen by which field (`text` or `thinking`) the payload contains.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum CohereContentDeltaPart {
    /// Ordinary response text.
    Text { text: String },
    /// Thinking text.
    Thinking { thinking: String },
}
814
/// Which kind of content block a chunk refers to.
#[derive(Debug, Clone, Copy)]
enum CohereContentKind {
    /// Ordinary response text.
    Text,
    /// Thinking text.
    Thinking,
}
820
821impl CohereContentStart {
822 fn kind_and_text(self) -> (CohereContentKind, String) {
823 match self {
824 Self::Text { text } => (CohereContentKind::Text, text),
825 Self::Thinking { thinking } => (CohereContentKind::Thinking, thinking),
826 }
827 }
828}
829
830impl CohereContentDeltaPart {
831 fn kind_and_text(self) -> (CohereContentKind, String) {
832 match self {
833 Self::Text { text } => (CohereContentKind::Text, text),
834 Self::Thinking { thinking } => (CohereContentKind::Thinking, thinking),
835 }
836 }
837}
838
/// `delta` payload of a `tool-call-start` chunk.
#[derive(Debug, Deserialize)]
struct CohereToolCallStartDelta {
    /// Wrapper holding the starting tool call.
    message: CohereToolCallMessage<CohereToolCallStartBody>,
}
843
/// `delta` payload of a `tool-call-delta` chunk.
#[derive(Debug, Deserialize)]
struct CohereToolCallDelta {
    /// Wrapper holding the incremental tool call data.
    message: CohereToolCallMessage<CohereToolCallDeltaBody>,
}
848
/// `message` object inside a tool-call delta, generic over the tool call payload.
#[derive(Debug, Deserialize)]
struct CohereToolCallMessage<T> {
    /// The tool call payload (a single object, not a list).
    tool_calls: T,
}
853
/// Tool call described by a `tool-call-start` chunk.
#[derive(Debug, Deserialize)]
struct CohereToolCallStartBody {
    /// Identifier of the tool call.
    id: String,
    /// Function name and initial argument text.
    function: CohereToolCallFunctionStart,
}
859
/// Function part of a `tool-call-start` chunk.
#[derive(Debug, Deserialize)]
struct CohereToolCallFunctionStart {
    /// Name of the function being called.
    name: String,
    /// Initial argument text; may be empty.
    arguments: String,
}
865
/// Tool call data carried by a `tool-call-delta` chunk.
#[derive(Debug, Deserialize)]
struct CohereToolCallDeltaBody {
    /// Function part holding the argument fragment.
    function: CohereToolCallFunctionDelta,
}
870
/// Function part of a `tool-call-delta` chunk.
#[derive(Debug, Deserialize)]
struct CohereToolCallFunctionDelta {
    /// Argument text fragment to append to the active tool call.
    arguments: String,
}
875
/// `delta` payload of a `message-end` chunk.
#[derive(Debug, Deserialize)]
struct CohereMessageEndDelta {
    /// Finish reason string; `process_event` maps `MAX_TOKENS`, `TOOL_CALL`
    /// and `ERROR` specially and treats anything else as a normal stop.
    finish_reason: String,
    /// Token usage for the response.
    usage: CohereUsage,
}
881
/// `usage` object of a `message-end` chunk.
#[derive(Debug, Deserialize)]
struct CohereUsage {
    /// Input and output token counts.
    tokens: CohereUsageTokens,
}
886
/// Token counts reported at the end of a message.
#[derive(Debug, Deserialize)]
struct CohereUsageTokens {
    /// Tokens counted for the input.
    input_tokens: u64,
    /// Tokens counted for the output.
    output_tokens: u64,
}
892
893#[cfg(test)]
898mod tests {
899 use super::*;
900 use asupersync::runtime::RuntimeBuilder;
901 use futures::stream;
902 use serde_json::{Value, json};
903 use std::collections::HashMap;
904 use std::io::{Read, Write};
905 use std::net::TcpListener;
906 use std::path::PathBuf;
907 use std::sync::mpsc;
908 use std::time::Duration;
909
    /// Top-level shape of a provider response fixture file.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        /// Test cases contained in the fixture.
        cases: Vec<ProviderCase>,
    }
916
    /// One fixture case: input events and the event summaries they should produce.
    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        /// Case name, used in assertion failure messages.
        name: String,
        /// Raw JSON events fed to the stream under test.
        events: Vec<Value>,
        /// Expected summaries of the emitted stream events, in order.
        expected: Vec<EventSummary>,
    }
923
    /// Comparable summary of a `StreamEvent`; fields absent from the fixture
    /// JSON deserialize to `None`.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        /// Event kind label, e.g. `"start"`, `"text_delta"`, `"done"`.
        kind: String,
        /// Content block index, for events that carry one.
        #[serde(default)]
        content_index: Option<usize>,
        /// Delta text, for delta events.
        #[serde(default)]
        delta: Option<String>,
        /// Final content, for end events.
        #[serde(default)]
        content: Option<String>,
        /// Stop reason label, for done and error events.
        #[serde(default)]
        reason: Option<String>,
    }
936
937 fn load_fixture(file_name: &str) -> ProviderFixture {
938 let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
939 .join("tests/fixtures/provider_responses")
940 .join(file_name);
941 let raw = std::fs::read_to_string(path).expect("fixture read");
942 serde_json::from_str(&raw).expect("fixture parse")
943 }
944
945 fn summarize_event(event: &StreamEvent) -> EventSummary {
946 match event {
947 StreamEvent::Start { .. } => EventSummary {
948 kind: "start".to_string(),
949 content_index: None,
950 delta: None,
951 content: None,
952 reason: None,
953 },
954 StreamEvent::TextStart { content_index, .. } => EventSummary {
955 kind: "text_start".to_string(),
956 content_index: Some(*content_index),
957 delta: None,
958 content: None,
959 reason: None,
960 },
961 StreamEvent::TextDelta {
962 content_index,
963 delta,
964 ..
965 } => EventSummary {
966 kind: "text_delta".to_string(),
967 content_index: Some(*content_index),
968 delta: Some(delta.clone()),
969 content: None,
970 reason: None,
971 },
972 StreamEvent::TextEnd {
973 content_index,
974 content,
975 ..
976 } => EventSummary {
977 kind: "text_end".to_string(),
978 content_index: Some(*content_index),
979 delta: None,
980 content: Some(content.clone()),
981 reason: None,
982 },
983 StreamEvent::Done { reason, .. } => EventSummary {
984 kind: "done".to_string(),
985 content_index: None,
986 delta: None,
987 content: None,
988 reason: Some(reason_to_string(*reason)),
989 },
990 StreamEvent::Error { reason, .. } => EventSummary {
991 kind: "error".to_string(),
992 content_index: None,
993 delta: None,
994 content: None,
995 reason: Some(reason_to_string(*reason)),
996 },
997 _ => EventSummary {
998 kind: "other".to_string(),
999 content_index: None,
1000 delta: None,
1001 content: None,
1002 reason: None,
1003 },
1004 }
1005 }
1006
1007 fn reason_to_string(reason: StopReason) -> String {
1008 match reason {
1009 StopReason::Stop => "stop",
1010 StopReason::Length => "length",
1011 StopReason::ToolUse => "tool_use",
1012 StopReason::Error => "error",
1013 StopReason::Aborted => "aborted",
1014 }
1015 .to_string()
1016 }
1017
1018 #[test]
1019 fn test_stream_fixtures() {
1020 let fixture = load_fixture("cohere_stream.json");
1021 for case in fixture.cases {
1022 let events = collect_events(&case.events);
1023 let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
1024 assert_eq!(summaries, case.expected, "case {}", case.name);
1025 }
1026 }
1027
1028 #[test]
1031 fn test_provider_info() {
1032 let provider = CohereProvider::new("command-r");
1033 assert_eq!(provider.name(), "cohere");
1034 assert_eq!(provider.api(), "cohere-chat");
1035 }
1036
    /// Builds a request from a system prompt, one user message and one tool,
    /// then checks the serialized JSON: message roles and contents, `stream`,
    /// `max_tokens`, `temperature`, and the function-tool shape.
    #[test]
    fn test_build_request_includes_system_tools_and_v2_shape() {
        let provider = CohereProvider::new("command-r");
        let context = Context::owned(
            Some("You are concise.".to_string()),
            vec![Message::User(crate::model::UserMessage {
                content: UserContent::Text("Ping".to_string()),
                timestamp: 0,
            })],
            vec![ToolDef {
                name: "search".to_string(),
                description: "Search docs".to_string(),
                parameters: json!({
                    "type": "object",
                    "properties": {
                        "q": { "type": "string" }
                    },
                    "required": ["q"]
                }),
            }],
        );
        let options = StreamOptions {
            temperature: Some(0.2),
            max_tokens: Some(123),
            ..Default::default()
        };

        let request = provider.build_request(&context, &options);
        let value = serde_json::to_value(&request).expect("serialize request");

        assert_eq!(value["model"], "command-r");
        assert_eq!(value["messages"][0]["role"], "system");
        assert_eq!(value["messages"][0]["content"], "You are concise.");
        assert_eq!(value["messages"][1]["role"], "user");
        assert_eq!(value["messages"][1]["content"], "Ping");
        assert_eq!(value["stream"], true);
        assert_eq!(value["max_tokens"], 123);
        // The temperature is an `f32` in the request, so compare with a tolerance.
        let temperature = value["temperature"]
            .as_f64()
            .expect("temperature should be numeric");
        assert!((temperature - 0.2).abs() < 1e-6);
        assert_eq!(value["tools"][0]["type"], "function");
        assert_eq!(value["tools"][0]["function"]["name"], "search");
        assert_eq!(value["tools"][0]["function"]["description"], "Search docs");
        assert_eq!(
            value["tools"][0]["function"]["parameters"],
            json!({
                "type": "object",
                "properties": {
                    "q": { "type": "string" }
                },
                "required": ["q"]
            })
        );
    }
1092
1093 #[test]
1094 fn test_convert_tool_to_cohere_omits_empty_description() {
1095 let tool = ToolDef {
1096 name: "echo".to_string(),
1097 description: " ".to_string(),
1098 parameters: json!({
1099 "type": "object",
1100 "properties": {
1101 "text": { "type": "string" }
1102 }
1103 }),
1104 };
1105
1106 let converted = convert_tool_to_cohere(&tool);
1107 let value = serde_json::to_value(converted).expect("serialize converted tool");
1108 assert_eq!(value["type"], "function");
1109 assert_eq!(value["function"]["name"], "echo");
1110 assert!(value["function"].get("description").is_none());
1111 }
1112
    /// Feeds a text block followed by a tool call through `StreamState` by hand
    /// and checks that start, text delta, tool call end and a `ToolUse` done
    /// event are all produced.
    #[test]
    fn test_stream_parses_text_and_tool_call() {
        let runtime = RuntimeBuilder::current_thread()
            .build()
            .expect("runtime build");

        runtime.block_on(async move {
            let events = [
                serde_json::json!({ "type": "message-start", "id": "msg_1" }),
                serde_json::json!({
                    "type": "content-start",
                    "index": 0,
                    "delta": { "message": { "content": { "type": "text", "text": "Hello" } } }
                }),
                serde_json::json!({
                    "type": "content-delta",
                    "index": 0,
                    "delta": { "message": { "content": { "text": " world" } } }
                }),
                serde_json::json!({ "type": "content-end", "index": 0 }),
                serde_json::json!({
                    "type": "tool-call-start",
                    "delta": { "message": { "tool_calls": { "id": "call_1", "type": "function", "function": { "name": "echo", "arguments": "{\"text\":\"hi\"}" } } } }
                }),
                serde_json::json!({ "type": "tool-call-end" }),
                serde_json::json!({
                    "type": "message-end",
                    "delta": { "finish_reason": "TOOL_CALL", "usage": { "tokens": { "input_tokens": 1, "output_tokens": 2 } } }
                }),
            ];

            // Encode each event as one SSE `data:` frame.
            let byte_stream = stream::iter(
                events
                    .iter()
                    .map(|event| format!("data: {}\n\n", serde_json::to_string(event).unwrap()))
                    .map(|s| Ok(s.into_bytes())),
            );

            let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
            let mut state = StreamState::new(
                event_source,
                "command-r".to_string(),
                "cohere-chat".to_string(),
                "cohere".to_string(),
            );

            // Drive the state machine directly, collecting events after each message.
            let mut out = Vec::new();
            while let Some(item) = state.event_source.next().await {
                let msg = item.expect("SSE event");
                state.process_event(&msg.data).expect("process_event");
                out.extend(state.pending_events.drain(..));
                if state.finished {
                    break;
                }
            }

            assert!(matches!(out.first(), Some(StreamEvent::Start { .. })));
            assert!(out.iter().any(|e| matches!(e, StreamEvent::TextDelta { delta, .. } if delta.contains("Hello"))));
            assert!(out.iter().any(|e| matches!(e, StreamEvent::ToolCallEnd { tool_call, .. } if tool_call.name == "echo")));
            assert!(out.iter().any(|e| matches!(e, StreamEvent::Done { reason: StopReason::ToolUse, .. })));
        });
    }
1175
    /// A thinking block followed by a `MAX_TOKENS` finish yields thinking
    /// start/delta/end events and a done event with `StopReason::Length`.
    #[test]
    fn test_stream_parses_thinking_and_max_tokens_stop_reason() {
        let events = vec![
            json!({ "type": "message-start", "id": "msg_1" }),
            json!({
                "type": "content-start",
                "index": 0,
                "delta": { "message": { "content": { "type": "thinking", "thinking": "Plan" } } }
            }),
            json!({
                "type": "content-delta",
                "index": 0,
                "delta": { "message": { "content": { "thinking": " more" } } }
            }),
            json!({ "type": "content-end", "index": 0 }),
            json!({
                "type": "message-end",
                "delta": { "finish_reason": "MAX_TOKENS", "usage": { "tokens": { "input_tokens": 2, "output_tokens": 3 } } }
            }),
        ];

        let out = collect_events(&events);
        assert!(
            out.iter()
                .any(|e| matches!(e, StreamEvent::ThinkingStart { .. }))
        );
        assert!(out.iter().any(
            |e| matches!(e, StreamEvent::ThinkingDelta { delta, .. } if delta.contains("Plan"))
        ));
        // The end event carries the start text plus the delta text.
        assert!(
            out.iter()
                .any(|e| matches!(e, StreamEvent::ThinkingEnd { content, .. } if content.contains("Plan more")))
        );
        assert!(out.iter().any(|e| matches!(
            e,
            StreamEvent::Done {
                reason: StopReason::Length,
                ..
            }
        )));
    }
1217
1218 #[test]
1219 fn test_stream_sets_bearer_auth_header() {
1220 let captured = run_stream_and_capture_headers(Some("test-cohere-key"), HashMap::new())
1221 .expect("captured request");
1222 assert_eq!(
1223 captured.headers.get("authorization").map(String::as_str),
1224 Some("Bearer test-cohere-key")
1225 );
1226 assert_eq!(
1227 captured.headers.get("accept").map(String::as_str),
1228 Some("text/event-stream")
1229 );
1230
1231 let body: Value = serde_json::from_str(&captured.body).expect("body json");
1232 assert_eq!(body["model"], "command-r");
1233 assert_eq!(body["stream"], true);
1234 }
1235
1236 #[test]
1237 fn test_stream_uses_existing_authorization_header_without_api_key() {
1238 let mut headers = HashMap::new();
1239 headers.insert(
1240 "Authorization".to_string(),
1241 "Bearer from-custom-header".to_string(),
1242 );
1243 headers.insert("X-Test".to_string(), "1".to_string());
1244
1245 let captured = run_stream_and_capture_headers(None, headers).expect("captured request");
1246 assert_eq!(
1247 captured.headers.get("authorization").map(String::as_str),
1248 Some("Bearer from-custom-header")
1249 );
1250 assert_eq!(
1251 captured.headers.get("x-test").map(String::as_str),
1252 Some("1")
1253 );
1254 }
1255
1256 #[test]
1257 fn test_stream_compat_authorization_header_overrides_api_key_without_duplicate() {
1258 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1259 let mut custom_headers = HashMap::new();
1260 custom_headers.insert(
1261 "Authorization".to_string(),
1262 "Bearer compat-header".to_string(),
1263 );
1264 let provider = CohereProvider::new("command-r")
1265 .with_base_url(base_url)
1266 .with_compat(Some(CompatConfig {
1267 custom_headers: Some(custom_headers),
1268 ..Default::default()
1269 }));
1270 let context = Context::owned(
1271 Some("system".to_string()),
1272 vec![Message::User(crate::model::UserMessage {
1273 content: UserContent::Text("ping".to_string()),
1274 timestamp: 0,
1275 })],
1276 Vec::new(),
1277 );
1278 let options = StreamOptions {
1279 api_key: Some("test-cohere-key".to_string()),
1280 ..Default::default()
1281 };
1282
1283 let runtime = RuntimeBuilder::current_thread()
1284 .build()
1285 .expect("runtime build");
1286 runtime.block_on(async {
1287 let mut stream = provider.stream(&context, &options).await.expect("stream");
1288 while let Some(event) = stream.next().await {
1289 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1290 break;
1291 }
1292 }
1293 });
1294
1295 let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
1296 assert_eq!(
1297 captured.headers.get("authorization").map(String::as_str),
1298 Some("Bearer compat-header")
1299 );
1300 assert_eq!(captured.header_count("authorization"), 1);
1301 }
1302
1303 #[test]
1304 fn test_stream_compat_authorization_header_works_without_api_key() {
1305 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1306 let mut custom_headers = HashMap::new();
1307 custom_headers.insert(
1308 "Authorization".to_string(),
1309 "Bearer compat-header".to_string(),
1310 );
1311 let provider = CohereProvider::new("command-r")
1312 .with_base_url(base_url)
1313 .with_compat(Some(CompatConfig {
1314 custom_headers: Some(custom_headers),
1315 ..Default::default()
1316 }));
1317 let context = Context::owned(
1318 Some("system".to_string()),
1319 vec![Message::User(crate::model::UserMessage {
1320 content: UserContent::Text("ping".to_string()),
1321 timestamp: 0,
1322 })],
1323 Vec::new(),
1324 );
1325
1326 let runtime = RuntimeBuilder::current_thread()
1327 .build()
1328 .expect("runtime build");
1329 runtime.block_on(async {
1330 let mut stream = provider
1331 .stream(&context, &StreamOptions::default())
1332 .await
1333 .expect("stream");
1334 while let Some(event) = stream.next().await {
1335 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1336 break;
1337 }
1338 }
1339 });
1340
1341 let captured = rx.recv_timeout(Duration::from_secs(2)).expect("captured");
1342 assert_eq!(
1343 captured.headers.get("authorization").map(String::as_str),
1344 Some("Bearer compat-header")
1345 );
1346 assert_eq!(captured.header_count("authorization"), 1);
1347 }
1348
1349 fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
1350 let runtime = RuntimeBuilder::current_thread()
1351 .build()
1352 .expect("runtime build");
1353
1354 runtime.block_on(async {
1355 let byte_stream = stream::iter(
1356 events
1357 .iter()
1358 .map(|event| {
1359 format!(
1360 "data: {}\n\n",
1361 serde_json::to_string(event).expect("serialize event")
1362 )
1363 })
1364 .map(|s| Ok(s.into_bytes())),
1365 );
1366
1367 let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1368 let mut state = StreamState::new(
1369 event_source,
1370 "command-r".to_string(),
1371 "cohere-chat".to_string(),
1372 "cohere".to_string(),
1373 );
1374
1375 let mut out = Vec::new();
1376 while let Some(item) = state.event_source.next().await {
1377 let msg = item.expect("SSE event");
1378 state.process_event(&msg.data).expect("process_event");
1379 out.extend(state.pending_events.drain(..));
1380 if state.finished {
1381 break;
1382 }
1383 }
1384 out
1385 })
1386 }
1387
1388 #[derive(Debug)]
1389 struct CapturedRequest {
1390 headers: HashMap<String, String>,
1391 header_lines: Vec<(String, String)>,
1392 body: String,
1393 }
1394
1395 impl CapturedRequest {
1396 fn header_count(&self, name: &str) -> usize {
1397 self.header_lines
1398 .iter()
1399 .filter(|(key, _)| key.eq_ignore_ascii_case(name))
1400 .count()
1401 }
1402 }
1403
1404 fn run_stream_and_capture_headers(
1405 api_key: Option<&str>,
1406 extra_headers: HashMap<String, String>,
1407 ) -> Option<CapturedRequest> {
1408 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1409 let provider = CohereProvider::new("command-r").with_base_url(base_url);
1410 let context = Context::owned(
1411 Some("system".to_string()),
1412 vec![Message::User(crate::model::UserMessage {
1413 content: UserContent::Text("ping".to_string()),
1414 timestamp: 0,
1415 })],
1416 Vec::new(),
1417 );
1418 let options = StreamOptions {
1419 api_key: api_key.map(str::to_string),
1420 headers: extra_headers,
1421 ..Default::default()
1422 };
1423
1424 let runtime = RuntimeBuilder::current_thread()
1425 .build()
1426 .expect("runtime build");
1427 runtime.block_on(async {
1428 let mut stream = provider.stream(&context, &options).await.expect("stream");
1429 while let Some(event) = stream.next().await {
1430 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1431 break;
1432 }
1433 }
1434 });
1435
1436 rx.recv_timeout(Duration::from_secs(2)).ok()
1437 }
1438
1439 fn success_sse_body() -> String {
1440 [
1441 r#"data: {"type":"message-start","id":"msg_1"}"#,
1442 "",
1443 r#"data: {"type":"message-end","delta":{"finish_reason":"COMPLETE","usage":{"tokens":{"input_tokens":1,"output_tokens":1}}}}"#,
1444 "",
1445 ]
1446 .join("\n")
1447 }
1448
1449 fn spawn_test_server(
1450 status_code: u16,
1451 content_type: &str,
1452 body: &str,
1453 ) -> (String, mpsc::Receiver<CapturedRequest>) {
1454 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
1455 let addr = listener.local_addr().expect("local addr");
1456 let (tx, rx) = mpsc::channel();
1457 let body = body.to_string();
1458 let content_type = content_type.to_string();
1459
1460 std::thread::spawn(move || {
1461 let (mut socket, _) = listener.accept().expect("accept");
1462 socket
1463 .set_read_timeout(Some(Duration::from_secs(2)))
1464 .expect("set read timeout");
1465
1466 let mut bytes = Vec::new();
1467 let mut chunk = [0_u8; 4096];
1468 loop {
1469 match socket.read(&mut chunk) {
1470 Ok(0) => break,
1471 Ok(n) => {
1472 bytes.extend_from_slice(&chunk[..n]);
1473 if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
1474 break;
1475 }
1476 }
1477 Err(err)
1478 if err.kind() == std::io::ErrorKind::WouldBlock
1479 || err.kind() == std::io::ErrorKind::TimedOut =>
1480 {
1481 break;
1482 }
1483 Err(err) => panic!(),
1484 }
1485 }
1486
1487 let header_end = bytes
1488 .windows(4)
1489 .position(|window| window == b"\r\n\r\n")
1490 .expect("request header boundary");
1491 let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
1492 let (headers, header_lines) = parse_headers(&header_text);
1493 let mut request_body = bytes[header_end + 4..].to_vec();
1494
1495 let content_length = headers
1496 .get("content-length")
1497 .and_then(|value| value.parse::<usize>().ok())
1498 .unwrap_or(0);
1499 while request_body.len() < content_length {
1500 match socket.read(&mut chunk) {
1501 Ok(0) => break,
1502 Ok(n) => request_body.extend_from_slice(&chunk[..n]),
1503 Err(err)
1504 if err.kind() == std::io::ErrorKind::WouldBlock
1505 || err.kind() == std::io::ErrorKind::TimedOut =>
1506 {
1507 break;
1508 }
1509 Err(err) => panic!(),
1510 }
1511 }
1512
1513 let captured = CapturedRequest {
1514 headers,
1515 header_lines,
1516 body: String::from_utf8_lossy(&request_body).to_string(),
1517 };
1518 tx.send(captured).expect("send captured request");
1519
1520 let reason = match status_code {
1521 401 => "Unauthorized",
1522 500 => "Internal Server Error",
1523 _ => "OK",
1524 };
1525 let response = format!(
1526 "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
1527 body.len()
1528 );
1529 socket
1530 .write_all(response.as_bytes())
1531 .expect("write response");
1532 socket.flush().expect("flush response");
1533 });
1534
1535 (format!("http://{addr}"), rx)
1536 }
1537
1538 fn parse_headers(header_text: &str) -> (HashMap<String, String>, Vec<(String, String)>) {
1539 let mut headers = HashMap::new();
1540 let mut header_lines = Vec::new();
1541 for line in header_text.lines().skip(1) {
1542 if let Some((name, value)) = line.split_once(':') {
1543 let normalized_name = name.trim().to_ascii_lowercase();
1544 let normalized_value = value.trim().to_string();
1545 header_lines.push((normalized_name.clone(), normalized_value.clone()));
1546 headers.insert(normalized_name, normalized_value);
1547 }
1548 }
1549 (headers, header_lines)
1550 }
1551
1552 #[test]
1555 fn test_build_request_no_system_prompt() {
1556 let provider = CohereProvider::new("command-r-plus");
1557 let context = Context::owned(
1558 None,
1559 vec![Message::User(crate::model::UserMessage {
1560 content: UserContent::Text("Hi".to_string()),
1561 timestamp: 0,
1562 })],
1563 vec![],
1564 );
1565 let options = StreamOptions::default();
1566
1567 let req = provider.build_request(&context, &options);
1568 let value = serde_json::to_value(&req).expect("serialize");
1569
1570 assert_eq!(value["messages"][0]["role"], "user");
1572 assert_eq!(value["messages"][0]["content"], "Hi");
1573 let msgs = value["messages"].as_array().unwrap();
1575 assert!(
1576 !msgs.iter().any(|m| m["role"] == "system"),
1577 "No system message should be present"
1578 );
1579 }
1580
1581 #[test]
1582 fn test_build_request_default_max_tokens() {
1583 let provider = CohereProvider::new("command-r");
1584 let context = Context::owned(
1585 None,
1586 vec![Message::User(crate::model::UserMessage {
1587 content: UserContent::Text("test".to_string()),
1588 timestamp: 0,
1589 })],
1590 vec![],
1591 );
1592 let options = StreamOptions::default();
1593
1594 let req = provider.build_request(&context, &options);
1595 let value = serde_json::to_value(&req).expect("serialize");
1596
1597 assert_eq!(value["max_tokens"], DEFAULT_MAX_TOKENS);
1598 }
1599
1600 #[test]
1601 fn test_build_request_no_tools_omits_tools_field() {
1602 let provider = CohereProvider::new("command-r");
1603 let context = Context::owned(
1604 None,
1605 vec![Message::User(crate::model::UserMessage {
1606 content: UserContent::Text("test".to_string()),
1607 timestamp: 0,
1608 })],
1609 vec![],
1610 );
1611 let options = StreamOptions::default();
1612
1613 let req = provider.build_request(&context, &options);
1614 let value = serde_json::to_value(&req).expect("serialize");
1615
1616 assert!(
1617 value.get("tools").is_none() || value["tools"].is_null(),
1618 "tools field should be omitted when empty"
1619 );
1620 }
1621
1622 #[test]
1623 fn test_build_request_full_conversation_with_tool_call_and_result() {
1624 let provider = CohereProvider::new("command-r");
1625 let context = Context::owned(
1626 Some("Be concise.".to_string()),
1627 vec![
1628 Message::User(crate::model::UserMessage {
1629 content: UserContent::Text("Read /tmp/a.txt".to_string()),
1630 timestamp: 0,
1631 }),
1632 Message::assistant(AssistantMessage {
1633 content: vec![ContentBlock::ToolCall(ToolCall {
1634 id: "call_1".to_string(),
1635 name: "read".to_string(),
1636 arguments: serde_json::json!({"path": "/tmp/a.txt"}),
1637 thought_signature: None,
1638 })],
1639 api: "cohere-chat".to_string(),
1640 provider: "cohere".to_string(),
1641 model: "command-r".to_string(),
1642 usage: Usage::default(),
1643 stop_reason: StopReason::ToolUse,
1644 error_message: None,
1645 timestamp: 1,
1646 }),
1647 Message::tool_result(crate::model::ToolResultMessage {
1648 tool_call_id: "call_1".to_string(),
1649 tool_name: "read".to_string(),
1650 content: vec![ContentBlock::Text(TextContent::new("file contents"))],
1651 details: None,
1652 is_error: false,
1653 timestamp: 2,
1654 }),
1655 ],
1656 vec![ToolDef {
1657 name: "read".to_string(),
1658 description: "Read a file".to_string(),
1659 parameters: json!({"type": "object"}),
1660 }],
1661 );
1662 let options = StreamOptions::default();
1663
1664 let req = provider.build_request(&context, &options);
1665 let value = serde_json::to_value(&req).expect("serialize");
1666
1667 let msgs = value["messages"].as_array().unwrap();
1668 assert_eq!(msgs.len(), 4);
1670 assert_eq!(msgs[0]["role"], "system");
1671 assert_eq!(msgs[1]["role"], "user");
1672 assert_eq!(msgs[2]["role"], "assistant");
1673 assert_eq!(msgs[3]["role"], "tool");
1674
1675 assert!(msgs[2].get("content").is_none() || msgs[2]["content"].is_null());
1677 let tool_calls = msgs[2]["tool_calls"].as_array().unwrap();
1678 assert_eq!(tool_calls.len(), 1);
1679 assert_eq!(tool_calls[0]["id"], "call_1");
1680 assert_eq!(tool_calls[0]["type"], "function");
1681 assert_eq!(tool_calls[0]["function"]["name"], "read");
1682
1683 assert_eq!(msgs[3]["tool_call_id"], "call_1");
1685 assert_eq!(msgs[3]["content"], "file contents");
1686 }
1687
1688 #[test]
1689 fn test_build_request_assistant_text_preserved_alongside_tool_calls() {
1690 let provider = CohereProvider::new("command-r");
1691 let context = Context::owned(
1692 None,
1693 vec![Message::assistant(AssistantMessage {
1694 content: vec![
1695 ContentBlock::Text(TextContent::new("Let me read that file.")),
1696 ContentBlock::ToolCall(ToolCall {
1697 id: "call_1".to_string(),
1698 name: "read".to_string(),
1699 arguments: json!({"path": "/tmp/a.txt"}),
1700 thought_signature: None,
1701 }),
1702 ],
1703 api: "cohere-chat".to_string(),
1704 provider: "cohere".to_string(),
1705 model: "command-r".to_string(),
1706 usage: Usage::default(),
1707 stop_reason: StopReason::ToolUse,
1708 error_message: None,
1709 timestamp: 0,
1710 })],
1711 vec![],
1712 );
1713 let options = StreamOptions::default();
1714
1715 let req = provider.build_request(&context, &options);
1716 let value = serde_json::to_value(&req).expect("serialize");
1717 let msgs = value["messages"].as_array().unwrap();
1718
1719 assert_eq!(msgs[0]["role"], "assistant");
1720 assert_eq!(
1721 msgs[0]["content"].as_str(),
1722 Some("Let me read that file."),
1723 "Assistant text must be preserved when tool_calls are also present"
1724 );
1725 assert_eq!(msgs[0]["tool_calls"].as_array().unwrap().len(), 1);
1726 }
1727
1728 #[test]
1729 fn test_convert_custom_message_to_cohere() {
1730 let context = Context::owned(
1731 None,
1732 vec![Message::Custom(crate::model::CustomMessage {
1733 custom_type: "extension_note".to_string(),
1734 content: "Important context.".to_string(),
1735 display: false,
1736 details: None,
1737 timestamp: 0,
1738 })],
1739 vec![],
1740 );
1741
1742 let msgs = build_cohere_messages(&context);
1743 assert_eq!(msgs.len(), 1);
1744 let value = serde_json::to_value(&msgs[0]).expect("serialize");
1745 assert_eq!(value["role"], "user");
1746 assert_eq!(value["content"], "Important context.");
1747 }
1748
1749 #[test]
1750 fn test_convert_user_blocks_extracts_text_only() {
1751 let content = UserContent::Blocks(vec![
1752 ContentBlock::Text(TextContent::new("part 1")),
1753 ContentBlock::Image(crate::model::ImageContent {
1754 data: "aGVsbG8=".to_string(),
1755 mime_type: "image/png".to_string(),
1756 }),
1757 ContentBlock::Text(TextContent::new("part 2")),
1758 ]);
1759
1760 let text = extract_text_user_content(&content);
1761 assert_eq!(text, "part 1[Image: image/png (8 bytes)]part 2");
1762 }
1763
1764 #[test]
1767 fn test_custom_provider_name() {
1768 let provider = CohereProvider::new("command-r").with_provider_name("my-proxy");
1769 assert_eq!(provider.name(), "my-proxy");
1770 assert_eq!(provider.api(), "cohere-chat");
1771 }
1772
1773 #[test]
1774 fn test_custom_base_url() {
1775 let provider =
1776 CohereProvider::new("command-r").with_base_url("https://proxy.example.com/v2/chat");
1777 assert_eq!(provider.base_url, "https://proxy.example.com/v2/chat");
1778 }
1779
1780 #[test]
1783 fn test_stream_complete_finish_reason_maps_to_stop() {
1784 let events = vec![
1785 json!({ "type": "message-start", "id": "msg_1" }),
1786 json!({
1787 "type": "message-end",
1788 "delta": {
1789 "finish_reason": "COMPLETE",
1790 "usage": { "tokens": { "input_tokens": 5, "output_tokens": 10 } }
1791 }
1792 }),
1793 ];
1794
1795 let out = collect_events(&events);
1796 assert!(out.iter().any(|e| matches!(
1797 e,
1798 StreamEvent::Done {
1799 reason: StopReason::Stop,
1800 message,
1801 ..
1802 } if message.usage.input == 5 && message.usage.output == 10
1803 )));
1804 }
1805
1806 #[test]
1807 fn test_stream_error_finish_reason_maps_to_error() {
1808 let events = vec![
1809 json!({ "type": "message-start", "id": "msg_1" }),
1810 json!({
1811 "type": "message-end",
1812 "delta": {
1813 "finish_reason": "ERROR",
1814 "usage": { "tokens": { "input_tokens": 1, "output_tokens": 0 } }
1815 }
1816 }),
1817 ];
1818
1819 let out = collect_events(&events);
1820 assert!(out.iter().any(|e| matches!(
1821 e,
1822 StreamEvent::Done {
1823 reason: StopReason::Error,
1824 ..
1825 }
1826 )));
1827 }
1828
1829 #[test]
1830 fn test_stream_tool_call_with_streamed_arguments() {
1831 let events = vec![
1832 json!({ "type": "message-start", "id": "msg_1" }),
1833 json!({
1834 "type": "tool-call-start",
1835 "delta": {
1836 "message": {
1837 "tool_calls": {
1838 "id": "call_42",
1839 "type": "function",
1840 "function": { "name": "bash", "arguments": "{\"co" }
1841 }
1842 }
1843 }
1844 }),
1845 json!({
1846 "type": "tool-call-delta",
1847 "delta": {
1848 "message": {
1849 "tool_calls": {
1850 "function": { "arguments": "mmand\"" }
1851 }
1852 }
1853 }
1854 }),
1855 json!({
1856 "type": "tool-call-delta",
1857 "delta": {
1858 "message": {
1859 "tool_calls": {
1860 "function": { "arguments": ": \"ls -la\"}" }
1861 }
1862 }
1863 }
1864 }),
1865 json!({ "type": "tool-call-end" }),
1866 json!({
1867 "type": "message-end",
1868 "delta": {
1869 "finish_reason": "TOOL_CALL",
1870 "usage": { "tokens": { "input_tokens": 10, "output_tokens": 20 } }
1871 }
1872 }),
1873 ];
1874
1875 let out = collect_events(&events);
1876
1877 let tool_end = out
1879 .iter()
1880 .find(|e| matches!(e, StreamEvent::ToolCallEnd { .. }));
1881 assert!(tool_end.is_some(), "Expected ToolCallEnd event");
1882 if let Some(StreamEvent::ToolCallEnd { tool_call, .. }) = tool_end {
1883 assert_eq!(tool_call.name, "bash");
1884 assert_eq!(tool_call.id, "call_42");
1885 assert_eq!(tool_call.arguments["command"], "ls -la");
1886 }
1887 }
1888
1889 #[test]
1890 fn test_stream_unknown_event_type_ignored() {
1891 let events = vec![
1892 json!({ "type": "message-start", "id": "msg_1" }),
1893 json!({ "type": "some-future-event", "data": "ignored" }),
1894 json!({
1895 "type": "content-start",
1896 "index": 0,
1897 "delta": { "message": { "content": { "type": "text", "text": "OK" } } }
1898 }),
1899 json!({ "type": "content-end", "index": 0 }),
1900 json!({
1901 "type": "message-end",
1902 "delta": {
1903 "finish_reason": "COMPLETE",
1904 "usage": { "tokens": { "input_tokens": 1, "output_tokens": 1 } }
1905 }
1906 }),
1907 ];
1908
1909 let out = collect_events(&events);
1910 assert!(out.iter().any(|e| matches!(e, StreamEvent::Done { .. })));
1912 assert!(out.iter().any(|e| matches!(
1913 e,
1914 StreamEvent::TextStart {
1915 content_index: 0,
1916 ..
1917 }
1918 )));
1919 }
1920}
1921
/// Fuzzing entry points, compiled only when the `fuzzing` feature is enabled.
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Item type of the byte stream that backs the processor's SSE source.
    type FuzzChunk = std::result::Result<Vec<u8>, std::io::Error>;

    /// Pinned, boxed, always-empty byte stream used as the SSE source.
    type FuzzStream = Pin<Box<futures::stream::Empty<FuzzChunk>>>;

    /// Wraps a `StreamState` built over an empty byte stream so that event payloads can
    /// be pushed into it directly through [`Processor::process_event`].
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Creates a processor whose underlying SSE source never yields any bytes.
        pub fn new() -> Self {
            let no_bytes = stream::empty::<FuzzChunk>();
            let event_source = crate::sse::SseStream::new(Box::pin(no_bytes));
            let state = StreamState::new(
                event_source,
                String::from("cohere-fuzz"),
                String::from("cohere"),
                String::from("cohere"),
            );
            Self(state)
        }

        /// Passes `data` to the wrapped state's `process_event` and returns every event
        /// that is pending afterwards, leaving the pending queue empty.
        ///
        /// # Errors
        ///
        /// Propagates any error returned by the wrapped `process_event`.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            let Self(state) = self;
            state.process_event(data)?;
            let drained: Vec<StreamEvent> = state.pending_events.drain(..).collect();
            Ok(drained)
        }
    }
}