1use crate::api::{ChatCompletionChunk, ChatProvider, ModelInfo, ModelPermission};
2use crate::metrics::{Metrics, NoopMetrics};
3use crate::transport::{DynHttpTransportRef, HttpTransport};
4use crate::types::{
5 AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
6 UsageStatus,
7};
8use futures::stream::Stream;
9use std::clone::Clone;
10use std::collections::HashMap;
11use std::sync::Arc;
12#[cfg(feature = "unified_transport")]
13use std::time::Duration;
14
15#[cfg(not(feature = "unified_sse"))]
16#[allow(dead_code)]
#[cfg(not(feature = "unified_sse"))]
#[allow(dead_code)]
/// Locate the end of the first complete SSE event in `buffer`.
///
/// An event is terminated by a blank line: either `\n\n` or `\r\n\r\n`.
/// Returns the index one past the terminator (i.e. where the next event
/// starts), or `None` when no full event is buffered yet.
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    for (pos, pair) in buffer.windows(2).enumerate() {
        if pair == b"\n\n" {
            return Some(pos + 2);
        }
        if buffer[pos..].starts_with(b"\r\n\r\n") {
            return Some(pos + 4);
        }
    }
    None
}
35
36#[cfg(not(feature = "unified_sse"))]
37#[allow(dead_code)]
38fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
39 for line in event_text.lines() {
40 let line = line.trim();
41 if let Some(stripped) = line.strip_prefix("data: ") {
42 let data = stripped;
43 if data == "[DONE]" {
44 return Some(Ok(None));
45 }
46 return Some(parse_chunk_data(data));
47 }
48 }
49 None
50}
51
52#[cfg(not(feature = "unified_sse"))]
53fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
54 match serde_json::from_str::<serde_json::Value>(data) {
55 Ok(json) => {
56 let choices = json["choices"]
57 .as_array()
58 .map(|arr| {
59 arr.iter()
60 .enumerate()
61 .map(|(index, choice)| {
62 let delta = &choice["delta"];
63 crate::api::ChoiceDelta {
64 index: index as u32,
65 delta: crate::api::MessageDelta {
66 role: delta["role"].as_str().map(|r| match r {
67 "assistant" => crate::types::Role::Assistant,
68 "user" => crate::types::Role::User,
69 "system" => crate::types::Role::System,
70 _ => crate::types::Role::Assistant,
71 }),
72 content: delta["content"].as_str().map(str::to_string),
73 },
74 finish_reason: choice["finish_reason"].as_str().map(str::to_string),
75 }
76 })
77 .collect()
78 })
79 .unwrap_or_default();
80
81 Ok(Some(ChatCompletionChunk {
82 id: json["id"].as_str().unwrap_or_default().to_string(),
83 object: json["object"]
84 .as_str()
85 .unwrap_or("chat.completion.chunk")
86 .to_string(),
87 created: json["created"].as_u64().unwrap_or(0),
88 model: json["model"].as_str().unwrap_or_default().to_string(),
89 choices,
90 }))
91 }
92 Err(e) => Err(AiLibError::ProviderError(format!(
93 "JSON parse error: {}",
94 e
95 ))),
96 }
97}
98
/// Split `text` into chunks of roughly `max_len` bytes.
///
/// Cut points are adjusted so that no chunk ends inside a multi-byte UTF-8
/// character, and when possible the cut prefers the last ASCII space within
/// the window (the space itself is dropped). If a window contains no char
/// boundary after `i` (not possible for `max_len >= 4`), the cut extends
/// forward to the next boundary, so a chunk may slightly exceed `max_len`.
///
/// An empty `text` yields an empty vector. `max_len == 0` would stall the
/// cursor forever, so it is treated as "no limit" and returns the whole
/// text as a single chunk.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    let mut chunks = Vec::new();
    if text.is_empty() {
        return chunks;
    }
    // Guard against an infinite loop: with a zero budget, `end == cut == i`
    // and the cursor would never advance.
    if max_len == 0 {
        chunks.push(text.to_string());
        return chunks;
    }
    let bytes = text.as_bytes();
    let mut i = 0usize;
    let len = bytes.len();
    while i < len {
        let mut end = std::cmp::min(i + max_len, len);
        // Never cut inside a multi-byte character: back up to a boundary,
        // and if none exists after `i`, extend forward instead.
        if end < len && !text.is_char_boundary(end) {
            while end > i && !text.is_char_boundary(end) {
                end -= 1;
            }
            if end == i {
                end = std::cmp::min(i + max_len, len);
                while end < len && !text.is_char_boundary(end) {
                    end += 1;
                }
                if end == i {
                    end = len;
                }
            }
        }

        // Prefer breaking at the last space inside the window so words stay
        // intact; fall back to the hard boundary when there is none.
        let mut cut = end;
        if end < len {
            if let Some(pos) = text[i..end].rfind(' ') {
                cut = i + pos;
            }
        }
        if cut == i {
            cut = end;
        }

        let chunk = text[i..cut].to_string();
        chunks.push(chunk);
        i = cut;
        // Skip the separator space we broke on so it isn't re-emitted.
        if i < len && bytes[i] == b' ' {
            i += 1;
        }
    }
    chunks
}
150use futures::StreamExt;
151use tokio::sync::mpsc;
152use tokio_stream::wrappers::UnboundedReceiverStream;
153
154pub struct CohereAdapter {
155 #[allow(dead_code)] transport: DynHttpTransportRef,
157 api_key: String,
158 base_url: String,
159 metrics: Arc<dyn Metrics>,
160}
161
162impl CohereAdapter {
163 #[allow(dead_code)]
164 fn build_default_timeout_secs() -> u64 {
165 std::env::var("AI_HTTP_TIMEOUT_SECS")
166 .ok()
167 .and_then(|s| s.parse::<u64>().ok())
168 .unwrap_or(30)
169 }
170
171 fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
172 #[cfg(feature = "unified_transport")]
173 {
174 let timeout = Duration::from_secs(Self::build_default_timeout_secs());
175 let client = crate::transport::client_factory::build_shared_client().map_err(|e| {
176 AiLibError::NetworkError(format!("Failed to build http client: {}", e))
177 })?;
178 let t = HttpTransport::with_reqwest_client(client, timeout);
179 Ok(t.boxed())
180 }
181 #[cfg(not(feature = "unified_transport"))]
182 {
183 let t = HttpTransport::new();
184 return Ok(t.boxed());
185 }
186 }
187
188 pub fn new() -> Result<Self, AiLibError> {
190 let api_key = std::env::var("COHERE_API_KEY").map_err(|_| {
191 AiLibError::AuthenticationError(
192 "COHERE_API_KEY environment variable not set".to_string(),
193 )
194 })?;
195 let base_url = std::env::var("COHERE_BASE_URL")
196 .unwrap_or_else(|_| "https://api.cohere.ai".to_string());
197 Ok(Self {
198 transport: Self::build_default_transport()?,
199 api_key,
200 base_url,
201 metrics: Arc::new(NoopMetrics::new()),
202 })
203 }
204
205 pub fn new_with_overrides(
207 api_key: String,
208 base_url: Option<String>,
209 ) -> Result<Self, AiLibError> {
210 let resolved_base = base_url.unwrap_or_else(|| {
211 std::env::var("COHERE_BASE_URL").unwrap_or_else(|_| "https://api.cohere.ai".to_string())
212 });
213 Ok(Self {
214 transport: Self::build_default_transport()?,
215 api_key,
216 base_url: resolved_base,
217 metrics: Arc::new(NoopMetrics::new()),
218 })
219 }
220
221 pub fn with_transport(transport: HttpTransport, api_key: String, base_url: String) -> Self {
223 Self {
224 transport: transport.boxed(),
225 api_key,
226 base_url,
227 metrics: Arc::new(NoopMetrics::new()),
228 }
229 }
230
231 pub fn with_transport_ref(
233 transport: DynHttpTransportRef,
234 api_key: String,
235 base_url: String,
236 ) -> Self {
237 Self {
238 transport,
239 api_key,
240 base_url,
241 metrics: Arc::new(NoopMetrics::new()),
242 }
243 }
244
245 pub fn with_transport_ref_and_metrics(
247 transport: DynHttpTransportRef,
248 api_key: String,
249 base_url: String,
250 metrics: Arc<dyn Metrics>,
251 ) -> Self {
252 Self {
253 transport,
254 api_key,
255 base_url,
256 metrics,
257 }
258 }
259
260 fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
261 let msgs: Vec<serde_json::Value> = request
263 .messages
264 .iter()
265 .map(|msg| {
266 serde_json::json!({
267 "role": match msg.role {
268 Role::System => "system",
269 Role::User => "user",
270 Role::Assistant => "assistant",
271 },
272 "content": msg.content.as_text()
273 })
274 })
275 .collect();
276
277 let mut chat_body = serde_json::json!({
278 "model": request.model,
279 "messages": msgs,
280 });
281
282 if let Some(temp) = request.temperature {
284 chat_body["temperature"] =
285 serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
286 }
287 if let Some(max_tokens) = request.max_tokens {
288 chat_body["max_tokens"] =
289 serde_json::Value::Number(serde_json::Number::from(max_tokens));
290 }
291
292 request.apply_extensions(&mut chat_body);
293
294 chat_body
295 }
296
297 fn parse_response(
298 &self,
299 response: serde_json::Value,
300 ) -> Result<ChatCompletionResponse, AiLibError> {
301 let content = if let Some(c) = response.get("choices") {
303 c[0]["message"]["content"]
305 .as_str()
306 .map(|s| s.to_string())
307 .or_else(|| c[0]["text"].as_str().map(|s| s.to_string()))
308 } else if let Some(msg) = response.get("message") {
309 msg.get("content").and_then(|content_array| {
311 content_array
312 .as_array()
313 .and_then(|arr| arr.first())
314 .and_then(|content_obj| {
315 content_obj
316 .get("text")
317 .and_then(|t| t.as_str())
318 .map(|text| text.to_string())
319 })
320 })
321 } else if let Some(gens) = response.get("generations") {
322 gens[0]["text"].as_str().map(|s| s.to_string())
324 } else {
325 None
326 };
327
328 let content = content.unwrap_or_default();
329
330 let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
331 if let Some(fc_val) = response.get("function_call") {
332 if let Ok(fc) =
333 serde_json::from_value::<crate::types::function_call::FunctionCall>(fc_val.clone())
334 {
335 function_call = Some(fc);
336 } else if let Some(name) = fc_val
337 .get("name")
338 .and_then(|v| v.as_str())
339 .map(|s| s.to_string())
340 {
341 let args = fc_val.get("arguments").and_then(|a| {
342 if a.is_string() {
343 serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
344 } else {
345 Some(a.clone())
346 }
347 });
348 function_call = Some(crate::types::function_call::FunctionCall {
349 name,
350 arguments: args,
351 });
352 }
353 } else if let Some(tool_calls) = response.get("tool_calls").and_then(|v| v.as_array()) {
354 if let Some(first) = tool_calls.first() {
355 if let Some(func) = first.get("function").or_else(|| first.get("function_call")) {
356 if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
357 let mut args_opt = func.get("arguments").cloned();
358 if let Some(args_val) = &args_opt {
359 if args_val.is_string() {
360 if let Some(s) = args_val.as_str() {
361 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
362 {
363 args_opt = Some(parsed);
364 }
365 }
366 }
367 }
368 function_call = Some(crate::types::function_call::FunctionCall {
369 name: name.to_string(),
370 arguments: args_opt,
371 });
372 }
373 }
374 }
375 }
376
377 let choice = Choice {
378 index: 0,
379 message: Message {
380 role: Role::Assistant,
381 content: crate::types::common::Content::Text(content.clone()),
382 function_call,
383 },
384 finish_reason: None,
385 };
386
387 let usage = Usage {
388 prompt_tokens: 0,
389 completion_tokens: 0,
390 total_tokens: 0,
391 };
392
393 Ok(ChatCompletionResponse {
394 id: response["id"].as_str().unwrap_or_default().to_string(),
395 object: response["object"].as_str().unwrap_or_default().to_string(),
396 created: response["created"].as_u64().unwrap_or(0),
397 model: response["model"].as_str().unwrap_or_default().to_string(),
398 choices: vec![choice],
399 usage,
400 usage_status: UsageStatus::Unsupported, })
402 }
403}
404
405#[async_trait::async_trait]
406impl ChatProvider for CohereAdapter {
407 fn name(&self) -> &str {
408 "Cohere"
409 }
410
411 async fn chat(
412 &self,
413 request: ChatCompletionRequest,
414 ) -> Result<ChatCompletionResponse, AiLibError> {
415 self.metrics.incr_counter("cohere.requests", 1).await;
416 let timer = self.metrics.start_timer("cohere.request_duration_ms").await;
417
418 let _body = self.convert_request(&request);
419
420 let url_generate = format!("{}/v1/generate", self.base_url);
422
423 let mut headers = HashMap::new();
424 headers.insert(
425 "Authorization".to_string(),
426 format!("Bearer {}", self.api_key),
427 );
428 headers.insert("Content-Type".to_string(), "application/json".to_string());
429 headers.insert("Accept".to_string(), "application/json".to_string());
430
431 let prompt = request
433 .messages
434 .iter()
435 .map(|msg| match msg.role {
436 Role::System => format!("System: {}", msg.content.as_text()),
437 Role::User => format!("Human: {}", msg.content.as_text()),
438 Role::Assistant => format!("Assistant: {}", msg.content.as_text()),
439 })
440 .collect::<Vec<_>>()
441 .join("\n");
442
443 let mut generate_body = serde_json::json!({
444 "model": request.model,
445 "prompt": prompt,
446 });
447
448 if let Some(temp) = request.temperature {
449 generate_body["temperature"] =
450 serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
451 }
452 if let Some(max_tokens) = request.max_tokens {
453 generate_body["max_tokens"] =
454 serde_json::Value::Number(serde_json::Number::from(max_tokens));
455 }
456
457 request.apply_extensions(&mut generate_body);
458
459 let response_json = self
461 .transport
462 .post_json(&url_generate, Some(headers), generate_body)
463 .await?;
464
465 if let Some(t) = timer {
466 t.stop();
467 }
468
469 self.parse_response(response_json)
470 }
471
472 async fn stream(
473 &self,
474 _request: ChatCompletionRequest,
475 ) -> Result<
476 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
477 AiLibError,
478 > {
479 let mut stream_request = self.convert_request(&_request);
481 stream_request["stream"] = serde_json::Value::Bool(true);
482
483 let url = format!("{}/v1/chat", self.base_url);
484
485 let mut headers = HashMap::new();
486 headers.insert(
487 "Authorization".to_string(),
488 format!("Bearer {}", self.api_key),
489 );
490 if let Ok(byte_stream) = self
492 .transport
493 .post_stream(&url, Some(headers.clone()), stream_request.clone())
494 .await
495 {
496 let (tx, rx) = mpsc::unbounded_channel();
497 tokio::spawn(async move {
498 let mut buffer = Vec::new();
499 futures::pin_mut!(byte_stream);
500 while let Some(item) = byte_stream.next().await {
501 match item {
502 Ok(bytes) => {
503 buffer.extend_from_slice(&bytes);
504 #[cfg(feature = "unified_sse")]
505 {
506 while let Some(boundary) =
507 crate::sse::parser::find_event_boundary(&buffer)
508 {
509 let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
510 if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
511 if let Some(parsed) =
512 crate::sse::parser::parse_sse_event(event_text)
513 {
514 match parsed {
515 Ok(Some(chunk)) => {
516 if tx.send(Ok(chunk)).is_err() {
517 return;
518 }
519 }
520 Ok(None) => return,
521 Err(e) => {
522 let _ = tx.send(Err(e));
523 return;
524 }
525 }
526 }
527 }
528 }
529 }
530 #[cfg(not(feature = "unified_sse"))]
531 {
532 while let Some(boundary) = find_event_boundary(&buffer) {
533 let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
534 if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
535 if let Some(parsed) = parse_sse_event(event_text) {
536 match parsed {
537 Ok(Some(chunk)) => {
538 if tx.send(Ok(chunk)).is_err() {
539 return;
540 }
541 }
542 Ok(None) => return,
543 Err(e) => {
544 let _ = tx.send(Err(e));
545 return;
546 }
547 }
548 }
549 }
550 }
551 }
552 }
553 Err(e) => {
554 let _ = tx.send(Err(AiLibError::ProviderError(format!(
555 "Stream error: {}",
556 e
557 ))));
558 break;
559 }
560 }
561 }
562 });
563 let stream = UnboundedReceiverStream::new(rx);
564 return Ok(Box::new(Box::pin(stream)));
565 }
566
567 let finished = self.chat(_request.clone()).await?;
569 let text = finished
570 .choices
571 .first()
572 .map(|c| c.message.content.as_text())
573 .unwrap_or_default();
574
575 let (tx, rx) = mpsc::unbounded_channel();
576
577 tokio::spawn(async move {
578 let chunks = split_text_into_chunks(&text, 80);
579 for chunk in chunks {
580 let delta = crate::api::ChoiceDelta {
582 index: 0,
583 delta: crate::api::MessageDelta {
584 role: Some(crate::types::Role::Assistant),
585 content: Some(chunk.clone()),
586 },
587 finish_reason: None,
588 };
589 let chunk_obj = ChatCompletionChunk {
590 id: "simulated".to_string(),
591 object: "chat.completion.chunk".to_string(),
592 created: 0,
593 model: finished.model.clone(),
594 choices: vec![delta],
595 };
596
597 if tx.send(Ok(chunk_obj)).is_err() {
598 return;
599 }
600 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
601 }
602 });
603
604 let stream = UnboundedReceiverStream::new(rx);
605 Ok(Box::new(Box::pin(stream)))
606 }
607
608 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
609 let url = format!("{}/v1/models", self.base_url);
611 let mut headers = HashMap::new();
612 headers.insert(
613 "Authorization".to_string(),
614 format!("Bearer {}", self.api_key),
615 );
616
617 let response = self.transport.get_json(&url, Some(headers)).await?;
618
619 Ok(response["models"]
620 .as_array()
621 .unwrap_or(&vec![])
622 .iter()
623 .filter_map(|m| {
624 m["id"]
625 .as_str()
626 .map(|s| s.to_string())
627 .or_else(|| m["name"].as_str().map(|s| s.to_string()))
628 })
629 .collect())
630 }
631
632 async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
633 Ok(ModelInfo {
634 id: model_id.to_string(),
635 object: "model".to_string(),
636 created: 0,
637 owned_by: "cohere".to_string(),
638 permission: vec![ModelPermission {
639 id: "default".to_string(),
640 object: "model_permission".to_string(),
641 created: 0,
642 allow_create_engine: false,
643 allow_sampling: true,
644 allow_logprobs: false,
645 allow_search_indices: false,
646 allow_view: true,
647 allow_fine_tuning: false,
648 organization: "*".to_string(),
649 group: None,
650 is_blocking: false,
651 }],
652 })
653 }
654}
655
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn split_text_handles_multibyte_without_panic() {
        // No ASCII spaces in this fixture, so no separators are dropped.
        let s = "这是一个包含中文字符的长文本,用于测试边界处理。🚀✨";
        let chunks = split_text_into_chunks(s, 8);
        assert!(!chunks.is_empty());
        for c in &chunks {
            // Chunks stay near the byte budget; the boundary search may
            // extend past max_len by at most one multi-byte character.
            assert!(c.len() <= 8 + 8);
        }
        // Since the input contains no spaces, re-joining the chunks must
        // reproduce the input byte-for-byte (the old `!combined.is_empty()`
        // fallback made this assertion effectively always true).
        assert_eq!(chunks.join(""), s);
    }
}