1use crate::error::{Error, Result};
7use crate::http::client::Client;
8use crate::model::{
9 AssistantMessage, ContentBlock, Message, StopReason, StreamEvent, TextContent, ThinkingContent,
10 ToolCall, Usage, UserContent,
11};
12use crate::models::CompatConfig;
13use crate::provider::{Context, Provider, StreamOptions, ToolDef};
14use crate::sse::SseStream;
15use async_trait::async_trait;
16use futures::StreamExt;
17use futures::stream::{self, Stream};
18use serde::{Deserialize, Serialize};
19use std::collections::{HashMap, VecDeque};
20use std::pin::Pin;
21
/// Endpoint used as the provider's base URL unless `with_base_url` overrides it.
const COHERE_CHAT_API_URL: &str = "https://api.cohere.com/v2/chat";
/// `max_tokens` value sent when the stream options do not specify one.
const DEFAULT_MAX_TOKENS: u32 = 4096;
28
/// Streaming chat provider that talks to a Cohere chat endpoint.
pub struct CohereProvider {
    /// HTTP client used to send the chat request.
    client: Client,
    /// Model identifier placed in the request body and in emitted messages.
    model: String,
    /// URL the chat request is POSTed to.
    base_url: String,
    /// Name returned by `Provider::name`.
    provider: String,
    /// Optional compatibility settings; `stream` reads `custom_headers` from it.
    compat: Option<CompatConfig>,
}
41
42impl CohereProvider {
43 pub fn new(model: impl Into<String>) -> Self {
44 Self {
45 client: Client::new(),
46 model: model.into(),
47 base_url: COHERE_CHAT_API_URL.to_string(),
48 provider: "cohere".to_string(),
49 compat: None,
50 }
51 }
52
53 #[must_use]
54 pub fn with_provider_name(mut self, provider: impl Into<String>) -> Self {
55 self.provider = provider.into();
56 self
57 }
58
59 #[must_use]
60 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
61 self.base_url = base_url.into();
62 self
63 }
64
65 #[must_use]
66 pub fn with_client(mut self, client: Client) -> Self {
67 self.client = client;
68 self
69 }
70
71 #[must_use]
73 pub fn with_compat(mut self, compat: Option<CompatConfig>) -> Self {
74 self.compat = compat;
75 self
76 }
77
78 pub fn build_request(&self, context: &Context<'_>, options: &StreamOptions) -> CohereRequest {
79 let messages = build_cohere_messages(context);
80
81 let tools: Option<Vec<CohereTool>> = if context.tools.is_empty() {
82 None
83 } else {
84 Some(context.tools.iter().map(convert_tool_to_cohere).collect())
85 };
86
87 CohereRequest {
88 model: self.model.clone(),
89 messages,
90 max_tokens: options.max_tokens.or(Some(DEFAULT_MAX_TOKENS)),
91 temperature: options.temperature,
92 tools,
93 stream: true,
94 }
95 }
96}
97
98#[async_trait]
99#[allow(clippy::too_many_lines)]
100impl Provider for CohereProvider {
101 fn name(&self) -> &str {
102 &self.provider
103 }
104
105 fn api(&self) -> &'static str {
106 "cohere-chat"
107 }
108
109 fn model_id(&self) -> &str {
110 &self.model
111 }
112
113 async fn stream(
114 &self,
115 context: &Context<'_>,
116 options: &StreamOptions,
117 ) -> Result<Pin<Box<dyn Stream<Item = Result<StreamEvent>> + Send>>> {
118 let has_authorization_header = options
119 .headers
120 .keys()
121 .any(|key| key.eq_ignore_ascii_case("authorization"));
122
123 let auth_value = if has_authorization_header {
124 None
125 } else {
126 Some(
127 options
128 .api_key
129 .clone()
130 .or_else(|| std::env::var("COHERE_API_KEY").ok())
131 .ok_or_else(|| Error::provider("cohere", "Missing API key for Cohere. Set COHERE_API_KEY or configure in settings."))?,
132 )
133 };
134
135 let request_body = self.build_request(context, options);
136
137 let mut request = self
139 .client
140 .post(&self.base_url)
141 .header("Accept", "text/event-stream");
142
143 if let Some(auth_value) = auth_value {
144 request = request.header("Authorization", format!("Bearer {auth_value}"));
145 }
146
147 if let Some(compat) = &self.compat {
149 if let Some(custom_headers) = &compat.custom_headers {
150 for (key, value) in custom_headers {
151 request = request.header(key, value);
152 }
153 }
154 }
155
156 for (key, value) in &options.headers {
158 request = request.header(key, value);
159 }
160
161 let request = request.json(&request_body)?;
162
163 let response = Box::pin(request.send()).await?;
164 let status = response.status();
165 if !(200..300).contains(&status) {
166 let body = response
167 .text()
168 .await
169 .unwrap_or_else(|e| format!("<failed to read body: {e}>"));
170 return Err(Error::provider(
171 "cohere",
172 format!("Cohere API error (HTTP {status}): {body}"),
173 ));
174 }
175
176 let content_type = response
177 .headers()
178 .iter()
179 .find(|(name, _)| name.eq_ignore_ascii_case("content-type"))
180 .map(|(_, value)| value.to_ascii_lowercase());
181 if !content_type
182 .as_deref()
183 .is_some_and(|value| value.contains("text/event-stream"))
184 {
185 let message = content_type.map_or_else(
186 || {
187 format!(
188 "Cohere API protocol error (HTTP {status}): missing Content-Type (expected text/event-stream)"
189 )
190 },
191 |value| {
192 format!(
193 "Cohere API protocol error (HTTP {status}): unexpected Content-Type {value} (expected text/event-stream)"
194 )
195 },
196 );
197 return Err(Error::api(message));
198 }
199
200 let event_source = SseStream::new(response.bytes_stream());
201
202 let model = self.model.clone();
203 let api = self.api().to_string();
204 let provider = self.name().to_string();
205
206 let stream = stream::unfold(
207 StreamState::new(event_source, model, api, provider),
208 |mut state| async move {
209 loop {
210 if let Some(event) = state.pending_events.pop_front() {
211 return Some((Ok(event), state));
212 }
213
214 if state.finished {
215 return None;
216 }
217
218 match state.event_source.next().await {
219 Some(Ok(msg)) => {
220 if msg.data == "[DONE]" {
221 state.finish();
222 continue;
223 }
224
225 if let Err(e) = state.process_event(&msg.data) {
226 return Some((Err(e), state));
227 }
228 }
229 Some(Err(e)) => {
230 let err = Error::api(format!("SSE error: {e}"));
231 return Some((Err(err), state));
232 }
233 None => {
234 return Some((
236 Err(Error::api("Stream ended without Done event")),
237 state,
238 ));
239 }
240 }
241 }
242 },
243 );
244
245 Ok(Box::pin(stream))
246 }
247}
248
/// Tool call whose argument string is still being accumulated from the stream.
struct ToolCallAccum {
    /// Index of the tool-call block inside the partial message's content.
    content_index: usize,
    /// Tool call id taken from the tool-call-start event.
    id: String,
    /// Function name taken from the tool-call-start event.
    name: String,
    /// Concatenated argument fragments; parsed as JSON at tool-call-end.
    arguments: String,
}
259
/// Mutable state carried between polls of the response stream.
struct StreamState<S>
where
    S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
{
    /// Source of server-sent events.
    event_source: SseStream<S>,
    /// Message built up so far; cloned into `Start` and moved into `Done`.
    partial: AssistantMessage,
    /// Events produced but not yet yielded to the consumer.
    pending_events: VecDeque<StreamEvent>,
    /// Whether the `Start` event has been queued.
    started: bool,
    /// Whether the `Done` event has been queued.
    finished: bool,
    /// Maps a Cohere content index to the block's position in `partial.content`.
    content_index_map: HashMap<u32, usize>,
    /// Tool call currently receiving argument fragments, if any.
    active_tool_call: Option<ToolCallAccum>,
}
272
273impl<S> StreamState<S>
274where
275 S: Stream<Item = std::result::Result<Vec<u8>, std::io::Error>> + Unpin,
276{
277 fn new(event_source: SseStream<S>, model: String, api: String, provider: String) -> Self {
278 Self {
279 event_source,
280 partial: AssistantMessage {
281 content: Vec::new(),
282 api,
283 provider,
284 model,
285 usage: Usage::default(),
286 stop_reason: StopReason::Stop,
287 error_message: None,
288 timestamp: chrono::Utc::now().timestamp_millis(),
289 },
290 pending_events: VecDeque::new(),
291 started: false,
292 finished: false,
293 content_index_map: HashMap::new(),
294 active_tool_call: None,
295 }
296 }
297
298 fn ensure_started(&mut self) {
299 if !self.started {
300 self.started = true;
301 self.pending_events.push_back(StreamEvent::Start {
302 partial: self.partial.clone(),
303 });
304 }
305 }
306
307 fn content_block_for(&mut self, idx: u32, kind: CohereContentKind) -> usize {
308 if let Some(existing) = self.content_index_map.get(&idx) {
309 return *existing;
310 }
311
312 let content_index = self.partial.content.len();
313 match kind {
314 CohereContentKind::Text => {
315 self.partial
316 .content
317 .push(ContentBlock::Text(TextContent::new("")));
318 self.pending_events
319 .push_back(StreamEvent::TextStart { content_index });
320 }
321 CohereContentKind::Thinking => {
322 self.partial
323 .content
324 .push(ContentBlock::Thinking(ThinkingContent {
325 thinking: String::new(),
326 thinking_signature: None,
327 }));
328 self.pending_events
329 .push_back(StreamEvent::ThinkingStart { content_index });
330 }
331 }
332
333 self.content_index_map.insert(idx, content_index);
334 content_index
335 }
336
337 #[allow(clippy::too_many_lines)]
338 fn process_event(&mut self, data: &str) -> Result<()> {
339 let chunk: CohereStreamChunk =
340 serde_json::from_str(data).map_err(|e| Error::api(format!("JSON parse error: {e}")))?;
341
342 match chunk {
343 CohereStreamChunk::MessageStart { .. } => {
344 self.ensure_started();
345 }
346 CohereStreamChunk::ContentStart { index, delta } => {
347 self.ensure_started();
348 let (kind, initial) = delta.message.content.kind_and_text();
349 let content_index = self.content_block_for(index, kind);
350
351 if !initial.is_empty() {
352 match kind {
353 CohereContentKind::Text => {
354 if let Some(ContentBlock::Text(t)) =
355 self.partial.content.get_mut(content_index)
356 {
357 t.text.push_str(&initial);
358 }
359 self.pending_events.push_back(StreamEvent::TextDelta {
360 content_index,
361 delta: initial,
362 });
363 }
364 CohereContentKind::Thinking => {
365 if let Some(ContentBlock::Thinking(t)) =
366 self.partial.content.get_mut(content_index)
367 {
368 t.thinking.push_str(&initial);
369 }
370 self.pending_events.push_back(StreamEvent::ThinkingDelta {
371 content_index,
372 delta: initial,
373 });
374 }
375 }
376 }
377 }
378 CohereStreamChunk::ContentDelta { index, delta } => {
379 self.ensure_started();
380 let (kind, delta_text) = delta.message.content.kind_and_text();
381 let content_index = self.content_block_for(index, kind);
382
383 match kind {
384 CohereContentKind::Text => {
385 if let Some(ContentBlock::Text(t)) =
386 self.partial.content.get_mut(content_index)
387 {
388 t.text.push_str(&delta_text);
389 }
390 self.pending_events.push_back(StreamEvent::TextDelta {
391 content_index,
392 delta: delta_text,
393 });
394 }
395 CohereContentKind::Thinking => {
396 if let Some(ContentBlock::Thinking(t)) =
397 self.partial.content.get_mut(content_index)
398 {
399 t.thinking.push_str(&delta_text);
400 }
401 self.pending_events.push_back(StreamEvent::ThinkingDelta {
402 content_index,
403 delta: delta_text,
404 });
405 }
406 }
407 }
408 CohereStreamChunk::ContentEnd { index } => {
409 if let Some(content_index) = self.content_index_map.get(&index).copied() {
410 match self.partial.content.get(content_index) {
411 Some(ContentBlock::Text(t)) => {
412 self.pending_events.push_back(StreamEvent::TextEnd {
413 content_index,
414 content: t.text.clone(),
415 });
416 }
417 Some(ContentBlock::Thinking(t)) => {
418 self.pending_events.push_back(StreamEvent::ThinkingEnd {
419 content_index,
420 content: t.thinking.clone(),
421 });
422 }
423 _ => {}
424 }
425 }
426 }
427 CohereStreamChunk::ToolCallStart { delta } => {
428 self.ensure_started();
429 let tc = delta.message.tool_calls;
430 let content_index = self.partial.content.len();
431 self.partial.content.push(ContentBlock::ToolCall(ToolCall {
432 id: tc.id.clone(),
433 name: tc.function.name.clone(),
434 arguments: serde_json::Value::Null,
435 thought_signature: None,
436 }));
437
438 self.active_tool_call = Some(ToolCallAccum {
439 content_index,
440 id: tc.id,
441 name: tc.function.name,
442 arguments: tc.function.arguments.clone(),
443 });
444
445 self.pending_events
446 .push_back(StreamEvent::ToolCallStart { content_index });
447 if !tc.function.arguments.is_empty() {
448 self.pending_events.push_back(StreamEvent::ToolCallDelta {
449 content_index,
450 delta: tc.function.arguments,
451 });
452 }
453 }
454 CohereStreamChunk::ToolCallDelta { delta } => {
455 self.ensure_started();
456 if let Some(active) = self.active_tool_call.as_mut() {
457 active
458 .arguments
459 .push_str(&delta.message.tool_calls.function.arguments);
460 self.pending_events.push_back(StreamEvent::ToolCallDelta {
461 content_index: active.content_index,
462 delta: delta.message.tool_calls.function.arguments,
463 });
464 }
465 }
466 CohereStreamChunk::ToolCallEnd => {
467 if let Some(active) = self.active_tool_call.take() {
468 self.ensure_started();
469 let parsed_args: serde_json::Value = serde_json::from_str(&active.arguments)
470 .unwrap_or_else(|e| {
471 tracing::warn!(
472 error = %e,
473 raw = %active.arguments,
474 "Failed to parse tool arguments as JSON"
475 );
476 serde_json::Value::Null
477 });
478
479 self.partial.stop_reason = StopReason::ToolUse;
480 self.pending_events.push_back(StreamEvent::ToolCallEnd {
481 content_index: active.content_index,
482 tool_call: ToolCall {
483 id: active.id,
484 name: active.name,
485 arguments: parsed_args.clone(),
486 thought_signature: None,
487 },
488 });
489
490 if let Some(ContentBlock::ToolCall(block)) =
491 self.partial.content.get_mut(active.content_index)
492 {
493 block.arguments = parsed_args;
494 }
495 }
496 }
497 CohereStreamChunk::MessageEnd { delta } => {
498 self.ensure_started();
499 self.partial.usage.input = delta.usage.tokens.input_tokens;
500 self.partial.usage.output = delta.usage.tokens.output_tokens;
501 self.partial.usage.total_tokens =
502 delta.usage.tokens.input_tokens + delta.usage.tokens.output_tokens;
503
504 self.partial.stop_reason = match delta.finish_reason.as_str() {
505 "MAX_TOKENS" => StopReason::Length,
506 "TOOL_CALL" => StopReason::ToolUse,
507 "ERROR" => StopReason::Error,
508 _ => StopReason::Stop,
509 };
510
511 self.finish();
512 }
513 CohereStreamChunk::Unknown => {}
514 }
515
516 Ok(())
517 }
518
519 fn finish(&mut self) {
520 if self.finished {
521 return;
522 }
523 let reason = self.partial.stop_reason;
524 self.pending_events.push_back(StreamEvent::Done {
525 reason,
526 message: std::mem::take(&mut self.partial),
527 });
528 self.finished = true;
529 }
530}
531
/// JSON body of a chat request; `None` fields are left out of the output.
#[derive(Debug, Serialize)]
pub struct CohereRequest {
    /// Model identifier.
    model: String,
    /// Conversation messages in order.
    messages: Vec<CohereMessage>,
    /// Upper bound on generated tokens.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_tokens: Option<u32>,
    /// Sampling temperature, when the caller set one.
    #[serde(skip_serializing_if = "Option::is_none")]
    temperature: Option<f32>,
    /// Tool definitions offered to the model.
    #[serde(skip_serializing_if = "Option::is_none")]
    tools: Option<Vec<CohereTool>>,
    /// Whether a streamed response is requested.
    stream: bool,
}
548
/// One chat message, serialized with a lowercase `role` tag naming the variant.
#[derive(Debug, Serialize)]
#[serde(tag = "role", rename_all = "lowercase")]
enum CohereMessage {
    /// System prompt.
    System {
        content: String,
    },
    /// User-authored text.
    User {
        content: String,
    },
    /// Assistant turn; each field is omitted from the JSON when `None`.
    Assistant {
        #[serde(skip_serializing_if = "Option::is_none")]
        content: Option<String>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_calls: Option<Vec<CohereToolCallRef>>,
        #[serde(skip_serializing_if = "Option::is_none")]
        tool_plan: Option<String>,
    },
    /// Result of a tool call, linked to the call by `tool_call_id`.
    Tool {
        content: String,
        tool_call_id: String,
    },
}
571
/// Tool call echoed back to the API inside an assistant message.
#[derive(Debug, Serialize)]
struct CohereToolCallRef {
    /// Identifier of the tool call.
    id: String,
    /// Serialized as `type`.
    #[serde(rename = "type")]
    r#type: &'static str,
    /// Function name and arguments of the call.
    function: CohereFunctionRef,
}
579
/// Function part of an echoed tool call.
#[derive(Debug, Serialize)]
struct CohereFunctionRef {
    /// Function name.
    name: String,
    /// Arguments as a JSON-encoded string.
    arguments: String,
}
585
/// Tool definition entry in the request's `tools` array.
#[derive(Debug, Serialize)]
struct CohereTool {
    /// Serialized as `type`.
    #[serde(rename = "type")]
    r#type: &'static str,
    /// The function being offered.
    function: CohereFunction,
}
592
/// Function description inside a tool definition.
#[derive(Debug, Serialize)]
struct CohereFunction {
    /// Function name.
    name: String,
    /// Human-readable description; omitted from the JSON when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    description: Option<String>,
    /// Parameter schema, passed through as JSON.
    parameters: serde_json::Value,
}
600
601fn convert_tool_to_cohere(tool: &ToolDef) -> CohereTool {
602 CohereTool {
603 r#type: "function",
604 function: CohereFunction {
605 name: tool.name.clone(),
606 description: if tool.description.trim().is_empty() {
607 None
608 } else {
609 Some(tool.description.clone())
610 },
611 parameters: tool.parameters.clone(),
612 },
613 }
614}
615
616fn build_cohere_messages(context: &Context<'_>) -> Vec<CohereMessage> {
617 let mut out = Vec::new();
618
619 if let Some(system) = &context.system_prompt {
620 out.push(CohereMessage::System {
621 content: system.to_string(),
622 });
623 }
624
625 for message in context.messages.iter() {
626 match message {
627 Message::User(user) => out.push(CohereMessage::User {
628 content: extract_text_user_content(&user.content),
629 }),
630 Message::Custom(custom) => out.push(CohereMessage::User {
631 content: custom.content.clone(),
632 }),
633 Message::Assistant(assistant) => {
634 let mut text = String::new();
635 let mut tool_calls = Vec::new();
636
637 for block in &assistant.content {
638 match block {
639 ContentBlock::Text(t) => text.push_str(&t.text),
640 ContentBlock::ToolCall(tc) => tool_calls.push(CohereToolCallRef {
641 id: tc.id.clone(),
642 r#type: "function",
643 function: CohereFunctionRef {
644 name: tc.name.clone(),
645 arguments: tc.arguments.to_string(),
646 },
647 }),
648 _ => {}
649 }
650 }
651
652 out.push(CohereMessage::Assistant {
653 content: if tool_calls.is_empty() && !text.is_empty() {
654 Some(text)
655 } else {
656 None
657 },
658 tool_calls: if tool_calls.is_empty() {
659 None
660 } else {
661 Some(tool_calls)
662 },
663 tool_plan: None,
664 });
665 }
666 Message::ToolResult(result) => {
667 let mut content = String::new();
668 for (i, block) in result.content.iter().enumerate() {
669 if i > 0 {
670 content.push('\n');
671 }
672 if let ContentBlock::Text(t) = block {
673 content.push_str(&t.text);
674 }
675 }
676 out.push(CohereMessage::Tool {
677 content,
678 tool_call_id: result.tool_call_id.clone(),
679 });
680 }
681 }
682 }
683
684 out
685}
686
687fn extract_text_user_content(content: &UserContent) -> String {
688 match content {
689 UserContent::Text(text) => text.clone(),
690 UserContent::Blocks(blocks) => {
691 let mut out = String::new();
692 for block in blocks {
693 match block {
694 ContentBlock::Text(t) => out.push_str(&t.text),
695 ContentBlock::Image(img) => {
696 use std::fmt::Write as _;
697 let _ =
698 write!(out, "[Image: {} ({} bytes)]", img.mime_type, img.data.len());
699 }
700 _ => {}
701 }
702 }
703 out
704 }
705 }
706}
707
/// One streamed event, selected by its `type` field. Unrecognized types
/// deserialize to `Unknown`.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum CohereStreamChunk {
    /// Start of the response message.
    #[serde(rename = "message-start")]
    MessageStart { id: Option<String> },
    /// A content block opens at `index`, possibly with initial text.
    #[serde(rename = "content-start")]
    ContentStart {
        index: u32,
        delta: CohereContentStartDelta,
    },
    /// More text for the content block at `index`.
    #[serde(rename = "content-delta")]
    ContentDelta {
        index: u32,
        delta: CohereContentDelta,
    },
    /// The content block at `index` is complete.
    #[serde(rename = "content-end")]
    ContentEnd { index: u32 },
    /// A tool call begins.
    #[serde(rename = "tool-call-start")]
    ToolCallStart { delta: CohereToolCallStartDelta },
    /// Another fragment of the current tool call's arguments.
    #[serde(rename = "tool-call-delta")]
    ToolCallDelta { delta: CohereToolCallDelta },
    /// The current tool call is complete.
    #[serde(rename = "tool-call-end")]
    ToolCallEnd,
    /// End of the response, carrying the finish reason and usage.
    #[serde(rename = "message-end")]
    MessageEnd { delta: CohereMessageEndDelta },
    /// Any event type not listed above.
    #[serde(other)]
    Unknown,
}
740
/// `delta` payload of a content-start event.
#[derive(Debug, Deserialize)]
struct CohereContentStartDelta {
    message: CohereDeltaMessage<CohereContentStart>,
}
745
/// `delta` payload of a content-delta event.
#[derive(Debug, Deserialize)]
struct CohereContentDelta {
    message: CohereDeltaMessage<CohereContentDeltaPart>,
}
750
/// `message` wrapper shared by content events; `T` is the content payload.
#[derive(Debug, Deserialize)]
struct CohereDeltaMessage<T> {
    content: T,
}
755
/// Content of a content-start event, selected by its `type` field.
#[derive(Debug, Deserialize)]
#[serde(tag = "type")]
enum CohereContentStart {
    /// Text block with its initial text.
    #[serde(rename = "text")]
    Text { text: String },
    /// Thinking block with its initial text.
    #[serde(rename = "thinking")]
    Thinking { thinking: String },
}
764
/// Content of a content-delta event. Untagged: variants are tried in order,
/// so the payload is chosen by which field (`text` or `thinking`) it has.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum CohereContentDeltaPart {
    /// Fragment of a text block.
    Text { text: String },
    /// Fragment of a thinking block.
    Thinking { thinking: String },
}
771
/// Kind of content block a stream event refers to.
#[derive(Debug, Clone, Copy)]
enum CohereContentKind {
    /// Regular response text.
    Text,
    /// Thinking text.
    Thinking,
}
777
778impl CohereContentStart {
779 fn kind_and_text(self) -> (CohereContentKind, String) {
780 match self {
781 Self::Text { text } => (CohereContentKind::Text, text),
782 Self::Thinking { thinking } => (CohereContentKind::Thinking, thinking),
783 }
784 }
785}
786
787impl CohereContentDeltaPart {
788 fn kind_and_text(self) -> (CohereContentKind, String) {
789 match self {
790 Self::Text { text } => (CohereContentKind::Text, text),
791 Self::Thinking { thinking } => (CohereContentKind::Thinking, thinking),
792 }
793 }
794}
795
/// `delta` payload of a tool-call-start event.
#[derive(Debug, Deserialize)]
struct CohereToolCallStartDelta {
    message: CohereToolCallMessage<CohereToolCallStartBody>,
}
800
/// `delta` payload of a tool-call-delta event.
#[derive(Debug, Deserialize)]
struct CohereToolCallDelta {
    message: CohereToolCallMessage<CohereToolCallDeltaBody>,
}
805
/// `message` wrapper shared by tool-call events; `T` is the `tool_calls` payload.
#[derive(Debug, Deserialize)]
struct CohereToolCallMessage<T> {
    tool_calls: T,
}
810
/// Tool call announced by a tool-call-start event.
#[derive(Debug, Deserialize)]
struct CohereToolCallStartBody {
    /// Identifier of the tool call.
    id: String,
    /// Function name and initial arguments.
    function: CohereToolCallFunctionStart,
}
816
/// Function details carried by a tool-call-start event.
#[derive(Debug, Deserialize)]
struct CohereToolCallFunctionStart {
    /// Function name.
    name: String,
    /// Initial argument text; may be empty.
    arguments: String,
}
822
/// Tool call fragment carried by a tool-call-delta event.
#[derive(Debug, Deserialize)]
struct CohereToolCallDeltaBody {
    function: CohereToolCallFunctionDelta,
}
827
/// Argument fragment of a tool-call-delta event.
#[derive(Debug, Deserialize)]
struct CohereToolCallFunctionDelta {
    /// Text appended to the accumulated argument string.
    arguments: String,
}
832
833#[derive(Debug, Deserialize)]
834struct CohereMessageEndDelta {
835 finish_reason: String,
836 usage: CohereUsage,
837}
838
839#[derive(Debug, Deserialize)]
840struct CohereUsage {
841 tokens: CohereUsageTokens,
842}
843
844#[derive(Debug, Deserialize)]
845struct CohereUsageTokens {
846 input_tokens: u64,
847 output_tokens: u64,
848}
849
850#[cfg(test)]
855mod tests {
856 use super::*;
857 use asupersync::runtime::RuntimeBuilder;
858 use futures::stream;
859 use serde_json::{Value, json};
860 use std::collections::HashMap;
861 use std::io::{Read, Write};
862 use std::net::TcpListener;
863 use std::path::PathBuf;
864 use std::sync::mpsc;
865 use std::time::Duration;
866
    /// Top-level shape of a provider response fixture file.
    #[derive(Debug, Deserialize)]
    struct ProviderFixture {
        /// Cases contained in the fixture.
        cases: Vec<ProviderCase>,
    }
873
    /// One fixture case: input stream events and the expected event summaries.
    #[derive(Debug, Deserialize)]
    struct ProviderCase {
        /// Case name, shown in assertion failures.
        name: String,
        /// Raw stream events fed to the parser.
        events: Vec<Value>,
        /// Summaries the parser output is compared against.
        expected: Vec<EventSummary>,
    }
880
    /// Reduced, comparable view of a `StreamEvent`; optional fields default
    /// to `None` when missing from the fixture.
    #[derive(Debug, Deserialize, Serialize, PartialEq)]
    struct EventSummary {
        /// Event kind label such as `"start"` or `"text_delta"`.
        kind: String,
        /// Content index, for events that carry one.
        #[serde(default)]
        content_index: Option<usize>,
        /// Delta text, for delta events.
        #[serde(default)]
        delta: Option<String>,
        /// Final content, for end events.
        #[serde(default)]
        content: Option<String>,
        /// Stop reason label, for done and error events.
        #[serde(default)]
        reason: Option<String>,
    }
893
894 fn load_fixture(file_name: &str) -> ProviderFixture {
895 let path = PathBuf::from(env!("CARGO_MANIFEST_DIR"))
896 .join("tests/fixtures/provider_responses")
897 .join(file_name);
898 let raw = std::fs::read_to_string(path).expect("fixture read");
899 serde_json::from_str(&raw).expect("fixture parse")
900 }
901
902 fn summarize_event(event: &StreamEvent) -> EventSummary {
903 match event {
904 StreamEvent::Start { .. } => EventSummary {
905 kind: "start".to_string(),
906 content_index: None,
907 delta: None,
908 content: None,
909 reason: None,
910 },
911 StreamEvent::TextStart { content_index, .. } => EventSummary {
912 kind: "text_start".to_string(),
913 content_index: Some(*content_index),
914 delta: None,
915 content: None,
916 reason: None,
917 },
918 StreamEvent::TextDelta {
919 content_index,
920 delta,
921 ..
922 } => EventSummary {
923 kind: "text_delta".to_string(),
924 content_index: Some(*content_index),
925 delta: Some(delta.clone()),
926 content: None,
927 reason: None,
928 },
929 StreamEvent::TextEnd {
930 content_index,
931 content,
932 ..
933 } => EventSummary {
934 kind: "text_end".to_string(),
935 content_index: Some(*content_index),
936 delta: None,
937 content: Some(content.clone()),
938 reason: None,
939 },
940 StreamEvent::Done { reason, .. } => EventSummary {
941 kind: "done".to_string(),
942 content_index: None,
943 delta: None,
944 content: None,
945 reason: Some(reason_to_string(*reason)),
946 },
947 StreamEvent::Error { reason, .. } => EventSummary {
948 kind: "error".to_string(),
949 content_index: None,
950 delta: None,
951 content: None,
952 reason: Some(reason_to_string(*reason)),
953 },
954 _ => EventSummary {
955 kind: "other".to_string(),
956 content_index: None,
957 delta: None,
958 content: None,
959 reason: None,
960 },
961 }
962 }
963
964 fn reason_to_string(reason: StopReason) -> String {
965 match reason {
966 StopReason::Stop => "stop",
967 StopReason::Length => "length",
968 StopReason::ToolUse => "tool_use",
969 StopReason::Error => "error",
970 StopReason::Aborted => "aborted",
971 }
972 .to_string()
973 }
974
975 #[test]
976 fn test_stream_fixtures() {
977 let fixture = load_fixture("cohere_stream.json");
978 for case in fixture.cases {
979 let events = collect_events(&case.events);
980 let summaries: Vec<EventSummary> = events.iter().map(summarize_event).collect();
981 assert_eq!(summaries, case.expected, "case {}", case.name);
982 }
983 }
984
985 #[test]
988 fn test_provider_info() {
989 let provider = CohereProvider::new("command-r");
990 assert_eq!(provider.name(), "cohere");
991 assert_eq!(provider.api(), "cohere-chat");
992 }
993
994 #[test]
995 fn test_build_request_includes_system_tools_and_v2_shape() {
996 let provider = CohereProvider::new("command-r");
997 let context = Context::owned(
998 Some("You are concise.".to_string()),
999 vec![Message::User(crate::model::UserMessage {
1000 content: UserContent::Text("Ping".to_string()),
1001 timestamp: 0,
1002 })],
1003 vec![ToolDef {
1004 name: "search".to_string(),
1005 description: "Search docs".to_string(),
1006 parameters: json!({
1007 "type": "object",
1008 "properties": {
1009 "q": { "type": "string" }
1010 },
1011 "required": ["q"]
1012 }),
1013 }],
1014 );
1015 let options = StreamOptions {
1016 temperature: Some(0.2),
1017 max_tokens: Some(123),
1018 ..Default::default()
1019 };
1020
1021 let request = provider.build_request(&context, &options);
1022 let value = serde_json::to_value(&request).expect("serialize request");
1023
1024 assert_eq!(value["model"], "command-r");
1025 assert_eq!(value["messages"][0]["role"], "system");
1026 assert_eq!(value["messages"][0]["content"], "You are concise.");
1027 assert_eq!(value["messages"][1]["role"], "user");
1028 assert_eq!(value["messages"][1]["content"], "Ping");
1029 assert_eq!(value["stream"], true);
1030 assert_eq!(value["max_tokens"], 123);
1031 let temperature = value["temperature"]
1032 .as_f64()
1033 .expect("temperature should be numeric");
1034 assert!((temperature - 0.2).abs() < 1e-6);
1035 assert_eq!(value["tools"][0]["type"], "function");
1036 assert_eq!(value["tools"][0]["function"]["name"], "search");
1037 assert_eq!(value["tools"][0]["function"]["description"], "Search docs");
1038 assert_eq!(
1039 value["tools"][0]["function"]["parameters"],
1040 json!({
1041 "type": "object",
1042 "properties": {
1043 "q": { "type": "string" }
1044 },
1045 "required": ["q"]
1046 })
1047 );
1048 }
1049
1050 #[test]
1051 fn test_convert_tool_to_cohere_omits_empty_description() {
1052 let tool = ToolDef {
1053 name: "echo".to_string(),
1054 description: " ".to_string(),
1055 parameters: json!({
1056 "type": "object",
1057 "properties": {
1058 "text": { "type": "string" }
1059 }
1060 }),
1061 };
1062
1063 let converted = convert_tool_to_cohere(&tool);
1064 let value = serde_json::to_value(converted).expect("serialize converted tool");
1065 assert_eq!(value["type"], "function");
1066 assert_eq!(value["function"]["name"], "echo");
1067 assert!(value["function"].get("description").is_none());
1068 }
1069
1070 #[test]
1071 fn test_stream_parses_text_and_tool_call() {
1072 let runtime = RuntimeBuilder::current_thread()
1073 .build()
1074 .expect("runtime build");
1075
1076 runtime.block_on(async move {
1077 let events = [
1078 serde_json::json!({ "type": "message-start", "id": "msg_1" }),
1079 serde_json::json!({
1080 "type": "content-start",
1081 "index": 0,
1082 "delta": { "message": { "content": { "type": "text", "text": "Hello" } } }
1083 }),
1084 serde_json::json!({
1085 "type": "content-delta",
1086 "index": 0,
1087 "delta": { "message": { "content": { "text": " world" } } }
1088 }),
1089 serde_json::json!({ "type": "content-end", "index": 0 }),
1090 serde_json::json!({
1091 "type": "tool-call-start",
1092 "delta": { "message": { "tool_calls": { "id": "call_1", "type": "function", "function": { "name": "echo", "arguments": "{\"text\":\"hi\"}" } } } }
1093 }),
1094 serde_json::json!({ "type": "tool-call-end" }),
1095 serde_json::json!({
1096 "type": "message-end",
1097 "delta": { "finish_reason": "TOOL_CALL", "usage": { "tokens": { "input_tokens": 1, "output_tokens": 2 } } }
1098 }),
1099 ];
1100
1101 let byte_stream = stream::iter(
1102 events
1103 .iter()
1104 .map(|event| format!("data: {}\n\n", serde_json::to_string(event).unwrap()))
1105 .map(|s| Ok(s.into_bytes())),
1106 );
1107
1108 let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1109 let mut state = StreamState::new(
1110 event_source,
1111 "command-r".to_string(),
1112 "cohere-chat".to_string(),
1113 "cohere".to_string(),
1114 );
1115
1116 let mut out = Vec::new();
1117 while let Some(item) = state.event_source.next().await {
1118 let msg = item.expect("SSE event");
1119 state.process_event(&msg.data).expect("process_event");
1120 out.extend(state.pending_events.drain(..));
1121 if state.finished {
1122 break;
1123 }
1124 }
1125
1126 assert!(matches!(out.first(), Some(StreamEvent::Start { .. })));
1127 assert!(out.iter().any(|e| matches!(e, StreamEvent::TextDelta { delta, .. } if delta.contains("Hello"))));
1128 assert!(out.iter().any(|e| matches!(e, StreamEvent::ToolCallEnd { tool_call, .. } if tool_call.name == "echo")));
1129 assert!(out.iter().any(|e| matches!(e, StreamEvent::Done { reason: StopReason::ToolUse, .. })));
1130 });
1131 }
1132
1133 #[test]
1134 fn test_stream_parses_thinking_and_max_tokens_stop_reason() {
1135 let events = vec![
1136 json!({ "type": "message-start", "id": "msg_1" }),
1137 json!({
1138 "type": "content-start",
1139 "index": 0,
1140 "delta": { "message": { "content": { "type": "thinking", "thinking": "Plan" } } }
1141 }),
1142 json!({
1143 "type": "content-delta",
1144 "index": 0,
1145 "delta": { "message": { "content": { "thinking": " more" } } }
1146 }),
1147 json!({ "type": "content-end", "index": 0 }),
1148 json!({
1149 "type": "message-end",
1150 "delta": { "finish_reason": "MAX_TOKENS", "usage": { "tokens": { "input_tokens": 2, "output_tokens": 3 } } }
1151 }),
1152 ];
1153
1154 let out = collect_events(&events);
1155 assert!(
1156 out.iter()
1157 .any(|e| matches!(e, StreamEvent::ThinkingStart { .. }))
1158 );
1159 assert!(out.iter().any(
1160 |e| matches!(e, StreamEvent::ThinkingDelta { delta, .. } if delta.contains("Plan"))
1161 ));
1162 assert!(
1163 out.iter()
1164 .any(|e| matches!(e, StreamEvent::ThinkingEnd { content, .. } if content.contains("Plan more")))
1165 );
1166 assert!(out.iter().any(|e| matches!(
1167 e,
1168 StreamEvent::Done {
1169 reason: StopReason::Length,
1170 ..
1171 }
1172 )));
1173 }
1174
1175 #[test]
1176 fn test_stream_sets_bearer_auth_header() {
1177 let captured = run_stream_and_capture_headers(Some("test-cohere-key"), HashMap::new())
1178 .expect("captured request");
1179 assert_eq!(
1180 captured.headers.get("authorization").map(String::as_str),
1181 Some("Bearer test-cohere-key")
1182 );
1183 assert_eq!(
1184 captured.headers.get("accept").map(String::as_str),
1185 Some("text/event-stream")
1186 );
1187
1188 let body: Value = serde_json::from_str(&captured.body).expect("body json");
1189 assert_eq!(body["model"], "command-r");
1190 assert_eq!(body["stream"], true);
1191 }
1192
1193 #[test]
1194 fn test_stream_uses_existing_authorization_header_without_api_key() {
1195 let mut headers = HashMap::new();
1196 headers.insert(
1197 "Authorization".to_string(),
1198 "Bearer from-custom-header".to_string(),
1199 );
1200 headers.insert("X-Test".to_string(), "1".to_string());
1201
1202 let captured = run_stream_and_capture_headers(None, headers).expect("captured request");
1203 assert_eq!(
1204 captured.headers.get("authorization").map(String::as_str),
1205 Some("Bearer from-custom-header")
1206 );
1207 assert_eq!(
1208 captured.headers.get("x-test").map(String::as_str),
1209 Some("1")
1210 );
1211 }
1212
1213 fn collect_events(events: &[Value]) -> Vec<StreamEvent> {
1214 let runtime = RuntimeBuilder::current_thread()
1215 .build()
1216 .expect("runtime build");
1217
1218 runtime.block_on(async {
1219 let byte_stream = stream::iter(
1220 events
1221 .iter()
1222 .map(|event| {
1223 format!(
1224 "data: {}\n\n",
1225 serde_json::to_string(event).expect("serialize event")
1226 )
1227 })
1228 .map(|s| Ok(s.into_bytes())),
1229 );
1230
1231 let event_source = crate::sse::SseStream::new(Box::pin(byte_stream));
1232 let mut state = StreamState::new(
1233 event_source,
1234 "command-r".to_string(),
1235 "cohere-chat".to_string(),
1236 "cohere".to_string(),
1237 );
1238
1239 let mut out = Vec::new();
1240 while let Some(item) = state.event_source.next().await {
1241 let msg = item.expect("SSE event");
1242 state.process_event(&msg.data).expect("process_event");
1243 out.extend(state.pending_events.drain(..));
1244 if state.finished {
1245 break;
1246 }
1247 }
1248 out
1249 })
1250 }
1251
/// The parts of an HTTP request that `spawn_test_server` records and hands back to the test.
#[derive(Debug)]
struct CapturedRequest {
    // Request headers as returned by `parse_headers` (names lower-cased, values trimmed).
    headers: HashMap<String, String>,
    // Request body decoded lossily as UTF-8.
    body: String,
}
1257
1258 fn run_stream_and_capture_headers(
1259 api_key: Option<&str>,
1260 extra_headers: HashMap<String, String>,
1261 ) -> Option<CapturedRequest> {
1262 let (base_url, rx) = spawn_test_server(200, "text/event-stream", &success_sse_body());
1263 let provider = CohereProvider::new("command-r").with_base_url(base_url);
1264 let context = Context::owned(
1265 Some("system".to_string()),
1266 vec![Message::User(crate::model::UserMessage {
1267 content: UserContent::Text("ping".to_string()),
1268 timestamp: 0,
1269 })],
1270 Vec::new(),
1271 );
1272 let options = StreamOptions {
1273 api_key: api_key.map(str::to_string),
1274 headers: extra_headers,
1275 ..Default::default()
1276 };
1277
1278 let runtime = RuntimeBuilder::current_thread()
1279 .build()
1280 .expect("runtime build");
1281 runtime.block_on(async {
1282 let mut stream = provider.stream(&context, &options).await.expect("stream");
1283 while let Some(event) = stream.next().await {
1284 if matches!(event.expect("stream event"), StreamEvent::Done { .. }) {
1285 break;
1286 }
1287 }
1288 });
1289
1290 rx.recv_timeout(Duration::from_secs(2)).ok()
1291 }
1292
/// Builds a minimal successful SSE response body: a `message-start` event followed by a
/// `message-end` event that reports `COMPLETE` and one input/output token each.
///
/// Every event is terminated by a blank line (`\n\n`), including the last one. An event that is
/// not followed by a blank line is incomplete per the SSE format and may be dropped at
/// end-of-stream, so the fixture must not rely on the parser flushing at EOF.
fn success_sse_body() -> String {
    const EVENT_PAYLOADS: [&str; 2] = [
        r#"{"type":"message-start","id":"msg_1"}"#,
        r#"{"type":"message-end","delta":{"finish_reason":"COMPLETE","usage":{"tokens":{"input_tokens":1,"output_tokens":1}}}}"#,
    ];

    EVENT_PAYLOADS
        .iter()
        .map(|payload| format!("data: {payload}\n\n"))
        .collect()
}
1302
1303 fn spawn_test_server(
1304 status_code: u16,
1305 content_type: &str,
1306 body: &str,
1307 ) -> (String, mpsc::Receiver<CapturedRequest>) {
1308 let listener = TcpListener::bind("127.0.0.1:0").expect("bind test server");
1309 let addr = listener.local_addr().expect("local addr");
1310 let (tx, rx) = mpsc::channel();
1311 let body = body.to_string();
1312 let content_type = content_type.to_string();
1313
1314 std::thread::spawn(move || {
1315 let (mut socket, _) = listener.accept().expect("accept");
1316 socket
1317 .set_read_timeout(Some(Duration::from_secs(2)))
1318 .expect("set read timeout");
1319
1320 let mut bytes = Vec::new();
1321 let mut chunk = [0_u8; 4096];
1322 loop {
1323 match socket.read(&mut chunk) {
1324 Ok(0) => break,
1325 Ok(n) => {
1326 bytes.extend_from_slice(&chunk[..n]);
1327 if bytes.windows(4).any(|window| window == b"\r\n\r\n") {
1328 break;
1329 }
1330 }
1331 Err(err)
1332 if err.kind() == std::io::ErrorKind::WouldBlock
1333 || err.kind() == std::io::ErrorKind::TimedOut =>
1334 {
1335 break;
1336 }
1337 Err(err) => panic!("read request failed: {err}"),
1338 }
1339 }
1340
1341 let header_end = bytes
1342 .windows(4)
1343 .position(|window| window == b"\r\n\r\n")
1344 .expect("request header boundary");
1345 let header_text = String::from_utf8_lossy(&bytes[..header_end]).to_string();
1346 let headers = parse_headers(&header_text);
1347 let mut request_body = bytes[header_end + 4..].to_vec();
1348
1349 let content_length = headers
1350 .get("content-length")
1351 .and_then(|value| value.parse::<usize>().ok())
1352 .unwrap_or(0);
1353 while request_body.len() < content_length {
1354 match socket.read(&mut chunk) {
1355 Ok(0) => break,
1356 Ok(n) => request_body.extend_from_slice(&chunk[..n]),
1357 Err(err)
1358 if err.kind() == std::io::ErrorKind::WouldBlock
1359 || err.kind() == std::io::ErrorKind::TimedOut =>
1360 {
1361 break;
1362 }
1363 Err(err) => panic!("read request body failed: {err}"),
1364 }
1365 }
1366
1367 let captured = CapturedRequest {
1368 headers,
1369 body: String::from_utf8_lossy(&request_body).to_string(),
1370 };
1371 tx.send(captured).expect("send captured request");
1372
1373 let reason = match status_code {
1374 401 => "Unauthorized",
1375 500 => "Internal Server Error",
1376 _ => "OK",
1377 };
1378 let response = format!(
1379 "HTTP/1.1 {status_code} {reason}\r\nContent-Type: {content_type}\r\nContent-Length: {}\r\nConnection: close\r\n\r\n{body}",
1380 body.len()
1381 );
1382 socket
1383 .write_all(response.as_bytes())
1384 .expect("write response");
1385 socket.flush().expect("flush response");
1386 });
1387
1388 (format!("http://{addr}"), rx)
1389 }
1390
/// Parses the head of an HTTP request into a header map.
///
/// The first line (the request line) is skipped. Every following line containing a `:` is split
/// at the first colon; the name is trimmed and lower-cased, the value is trimmed. Lines without
/// a colon are ignored, and a repeated header name keeps its last value.
fn parse_headers(header_text: &str) -> HashMap<String, String> {
    header_text
        .lines()
        .skip(1)
        .filter_map(|line| line.split_once(':'))
        .map(|(name, value)| (name.trim().to_ascii_lowercase(), value.trim().to_string()))
        .collect()
}
1400
/// Without a system prompt the serialized request starts with the user turn and contains no
/// `system` message at all.
#[test]
fn test_build_request_no_system_prompt() {
    let provider = CohereProvider::new("command-r-plus");
    let user_turn = Message::User(crate::model::UserMessage {
        content: UserContent::Text("Hi".to_string()),
        timestamp: 0,
    });
    let context = Context::owned(None, vec![user_turn], Vec::new());

    let request = provider.build_request(&context, &StreamOptions::default());
    let value = serde_json::to_value(&request).expect("serialize");

    let first = &value["messages"][0];
    assert_eq!(first["role"], "user");
    assert_eq!(first["content"], "Hi");

    let msgs = value["messages"].as_array().unwrap();
    assert!(
        msgs.iter().all(|m| m["role"] != "system"),
        "No system message should be present"
    );
}
1429
/// Default stream options must serialize `max_tokens` as `DEFAULT_MAX_TOKENS`.
#[test]
fn test_build_request_default_max_tokens() {
    let user_turn = Message::User(crate::model::UserMessage {
        content: UserContent::Text("test".to_string()),
        timestamp: 0,
    });
    let context = Context::owned(None, vec![user_turn], Vec::new());

    let request =
        CohereProvider::new("command-r").build_request(&context, &StreamOptions::default());
    let value = serde_json::to_value(&request).expect("serialize");

    assert_eq!(value["max_tokens"], DEFAULT_MAX_TOKENS);
}
1448
/// A context without tools must not produce a populated `tools` field: the key is either
/// absent or JSON `null`.
#[test]
fn test_build_request_no_tools_omits_tools_field() {
    let user_turn = Message::User(crate::model::UserMessage {
        content: UserContent::Text("test".to_string()),
        timestamp: 0,
    });
    let context = Context::owned(None, vec![user_turn], Vec::new());

    let request =
        CohereProvider::new("command-r").build_request(&context, &StreamOptions::default());
    let value = serde_json::to_value(&request).expect("serialize");

    let tools_absent_or_null = match value.get("tools") {
        None => true,
        Some(tools) => tools.is_null(),
    };
    assert!(
        tools_absent_or_null,
        "tools field should be omitted when empty"
    );
}
1470
/// A system prompt plus user → assistant tool call → tool result must serialize to four
/// messages in that role order, with the tool call on the assistant message and the tool
/// output linked back through `tool_call_id`.
#[test]
fn test_build_request_full_conversation_with_tool_call_and_result() {
    let user_turn = Message::User(crate::model::UserMessage {
        content: UserContent::Text("Read /tmp/a.txt".to_string()),
        timestamp: 0,
    });
    let assistant_turn = Message::assistant(AssistantMessage {
        content: vec![ContentBlock::ToolCall(ToolCall {
            id: "call_1".to_string(),
            name: "read".to_string(),
            arguments: serde_json::json!({"path": "/tmp/a.txt"}),
            thought_signature: None,
        })],
        api: "cohere-chat".to_string(),
        provider: "cohere".to_string(),
        model: "command-r".to_string(),
        usage: Usage::default(),
        stop_reason: StopReason::ToolUse,
        error_message: None,
        timestamp: 1,
    });
    let tool_turn = Message::tool_result(crate::model::ToolResultMessage {
        tool_call_id: "call_1".to_string(),
        tool_name: "read".to_string(),
        content: vec![ContentBlock::Text(TextContent::new("file contents"))],
        details: None,
        is_error: false,
        timestamp: 2,
    });
    let read_tool = ToolDef {
        name: "read".to_string(),
        description: "Read a file".to_string(),
        parameters: json!({"type": "object"}),
    };

    let provider = CohereProvider::new("command-r");
    let context = Context::owned(
        Some("Be concise.".to_string()),
        vec![user_turn, assistant_turn, tool_turn],
        vec![read_tool],
    );
    let request = provider.build_request(&context, &StreamOptions::default());
    let value = serde_json::to_value(&request).expect("serialize");

    let msgs = value["messages"].as_array().unwrap();
    assert_eq!(msgs.len(), 4);
    for (msg, expected_role) in msgs.iter().zip(["system", "user", "assistant", "tool"]) {
        assert_eq!(msg["role"], expected_role);
    }

    // The assistant turn carries only the tool call, so its text content is absent or null.
    let assistant = &msgs[2];
    let content_absent_or_null = match assistant.get("content") {
        None => true,
        Some(content) => content.is_null(),
    };
    assert!(content_absent_or_null);

    let tool_calls = assistant["tool_calls"].as_array().unwrap();
    assert_eq!(tool_calls.len(), 1);
    let call = &tool_calls[0];
    assert_eq!(call["id"], "call_1");
    assert_eq!(call["type"], "function");
    assert_eq!(call["function"]["name"], "read");

    let tool_msg = &msgs[3];
    assert_eq!(tool_msg["tool_call_id"], "call_1");
    assert_eq!(tool_msg["content"], "file contents");
}
1536
/// A custom message is converted into a single `user` message carrying its content.
#[test]
fn test_convert_custom_message_to_cohere() {
    let note = Message::Custom(crate::model::CustomMessage {
        custom_type: "extension_note".to_string(),
        content: "Important context.".to_string(),
        display: false,
        details: None,
        timestamp: 0,
    });
    let context = Context::owned(None, vec![note], Vec::new());

    let converted = build_cohere_messages(&context);
    assert_eq!(converted.len(), 1);

    let value = serde_json::to_value(&converted[0]).expect("serialize");
    assert_eq!(value["role"], "user");
    assert_eq!(value["content"], "Important context.");
}
1557
/// Mixed user blocks are flattened to text: text blocks are concatenated in order and an image
/// block shows up as an `[Image: <mime> (<n> bytes)]` placeholder.
#[test]
fn test_convert_user_blocks_extracts_text_only() {
    let image = ContentBlock::Image(crate::model::ImageContent {
        data: "aGVsbG8=".to_string(),
        mime_type: "image/png".to_string(),
    });
    let blocks = vec![
        ContentBlock::Text(TextContent::new("part 1")),
        image,
        ContentBlock::Text(TextContent::new("part 2")),
    ];

    let flattened = extract_text_user_content(&UserContent::Blocks(blocks));
    assert_eq!(flattened, "part 1[Image: image/png (8 bytes)]part 2");
}
1572
/// Overriding the provider name changes `name()` but leaves `api()` as `cohere-chat`.
#[test]
fn test_custom_provider_name() {
    let proxied = CohereProvider::new("command-r").with_provider_name("my-proxy");

    assert_eq!(proxied.name(), "my-proxy");
    assert_eq!(proxied.api(), "cohere-chat");
}
1581
/// `with_base_url` stores the given URL verbatim.
#[test]
fn test_custom_base_url() {
    let custom_url = "https://proxy.example.com/v2/chat";
    let provider = CohereProvider::new("command-r").with_base_url(custom_url);

    assert_eq!(provider.base_url, "https://proxy.example.com/v2/chat");
}
1588
/// A `COMPLETE` finish reason yields a `Done` event with `StopReason::Stop` whose message
/// carries the reported input/output token counts.
#[test]
fn test_stream_complete_finish_reason_maps_to_stop() {
    let message_start = json!({ "type": "message-start", "id": "msg_1" });
    let message_end = json!({
        "type": "message-end",
        "delta": {
            "finish_reason": "COMPLETE",
            "usage": { "tokens": { "input_tokens": 5, "output_tokens": 10 } }
        }
    });

    let out = collect_events(&[message_start, message_end]);

    let stopped_with_usage = out.iter().any(|event| match event {
        StreamEvent::Done {
            reason: StopReason::Stop,
            message,
            ..
        } => message.usage.input == 5 && message.usage.output == 10,
        _ => false,
    });
    assert!(stopped_with_usage);
}
1614
/// An `ERROR` finish reason yields a `Done` event with `StopReason::Error`.
#[test]
fn test_stream_error_finish_reason_maps_to_error() {
    let message_start = json!({ "type": "message-start", "id": "msg_1" });
    let message_end = json!({
        "type": "message-end",
        "delta": {
            "finish_reason": "ERROR",
            "usage": { "tokens": { "input_tokens": 1, "output_tokens": 0 } }
        }
    });

    let out = collect_events(&[message_start, message_end]);

    let saw_error_done = out.iter().any(|event| {
        matches!(
            event,
            StreamEvent::Done {
                reason: StopReason::Error,
                ..
            }
        )
    });
    assert!(saw_error_done);
}
1637
/// Tool-call arguments that arrive split across `tool-call-start` and two `tool-call-delta`
/// events are reassembled into one JSON object on the resulting `ToolCallEnd` event.
#[test]
fn test_stream_tool_call_with_streamed_arguments() {
    // Builds a `tool-call-delta` event carrying one fragment of the argument string.
    let arguments_delta = |fragment: &str| {
        json!({
            "type": "tool-call-delta",
            "delta": {
                "message": {
                    "tool_calls": {
                        "function": { "arguments": fragment }
                    }
                }
            }
        })
    };

    let events = vec![
        json!({ "type": "message-start", "id": "msg_1" }),
        json!({
            "type": "tool-call-start",
            "delta": {
                "message": {
                    "tool_calls": {
                        "id": "call_42",
                        "type": "function",
                        "function": { "name": "bash", "arguments": "{\"co" }
                    }
                }
            }
        }),
        arguments_delta("mmand\""),
        arguments_delta(": \"ls -la\"}"),
        json!({ "type": "tool-call-end" }),
        json!({
            "type": "message-end",
            "delta": {
                "finish_reason": "TOOL_CALL",
                "usage": { "tokens": { "input_tokens": 10, "output_tokens": 20 } }
            }
        }),
    ];

    let out = collect_events(&events);

    let finished_call = out.iter().find_map(|event| match event {
        StreamEvent::ToolCallEnd { tool_call, .. } => Some(tool_call),
        _ => None,
    });
    assert!(finished_call.is_some(), "Expected ToolCallEnd event");
    if let Some(tool_call) = finished_call {
        assert_eq!(tool_call.name, "bash");
        assert_eq!(tool_call.id, "call_42");
        assert_eq!(tool_call.arguments["command"], "ls -la");
    }
}
1697
/// An unrecognized event type in the middle of the stream does not disturb processing: the
/// text block at index 0 still starts and the stream still reaches `Done`.
#[test]
fn test_stream_unknown_event_type_ignored() {
    let unknown = json!({ "type": "some-future-event", "data": "ignored" });
    let events = vec![
        json!({ "type": "message-start", "id": "msg_1" }),
        unknown,
        json!({
            "type": "content-start",
            "index": 0,
            "delta": { "message": { "content": { "type": "text", "text": "OK" } } }
        }),
        json!({ "type": "content-end", "index": 0 }),
        json!({
            "type": "message-end",
            "delta": {
                "finish_reason": "COMPLETE",
                "usage": { "tokens": { "input_tokens": 1, "output_tokens": 1 } }
            }
        }),
    ];

    let out = collect_events(&events);

    let reached_done = out
        .iter()
        .any(|event| matches!(event, StreamEvent::Done { .. }));
    assert!(reached_done);

    let text_started = out.iter().any(|event| {
        matches!(
            event,
            StreamEvent::TextStart {
                content_index: 0,
                ..
            }
        )
    });
    assert!(text_started);
}
1729}
1730
/// Entry points for fuzzing the stream event processor; compiled only with the `fuzzing`
/// feature.
#[cfg(feature = "fuzzing")]
pub mod fuzz {
    use super::*;
    use futures::stream;
    use std::pin::Pin;

    /// Item type of the placeholder byte stream.
    type ChunkResult = std::result::Result<Vec<u8>, std::io::Error>;

    /// A byte stream that never yields anything; events are pushed in by hand instead.
    type FuzzStream = Pin<Box<futures::stream::Empty<ChunkResult>>>;

    /// Wraps a `StreamState` over an empty byte stream so raw event payloads can be fed to
    /// `process_event` directly.
    pub struct Processor(StreamState<FuzzStream>);

    impl Default for Processor {
        fn default() -> Self {
            Self::new()
        }
    }

    impl Processor {
        /// Creates a processor whose state is backed by an empty SSE source.
        pub fn new() -> Self {
            let source = crate::sse::SseStream::new(Box::pin(stream::empty::<ChunkResult>()));
            let state = StreamState::new(
                source,
                "cohere-fuzz".into(),
                "cohere".into(),
                "cohere".into(),
            );
            Self(state)
        }

        /// Processes one event payload and returns the events queued so far, leaving the
        /// pending queue empty.
        ///
        /// # Errors
        ///
        /// Returns the error from the underlying `process_event`. In that case nothing is
        /// drained, so events that were already queued stay pending until the next
        /// successful call.
        pub fn process_event(&mut self, data: &str) -> crate::error::Result<Vec<StreamEvent>> {
            let Self(state) = self;
            state.process_event(data)?;
            let drained = state.pending_events.drain(..).collect();
            Ok(drained)
        }
    }
}