1use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
2use crate::metrics::{Metrics, NoopMetrics};
3use crate::transport::{DynHttpTransportRef, HttpTransport};
4use crate::types::{
5 AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage, UsageStatus,
6};
7use futures::stream::Stream;
8use std::clone::Clone;
9use std::collections::HashMap;
10use std::sync::Arc;
11#[cfg(feature = "unified_transport")]
12use std::time::Duration;
13
/// Locate the end of the first complete SSE event in `buffer`.
///
/// Returns the byte offset just past the event terminator — a blank line
/// encoded as either `\n\n` or `\r\n\r\n` — or `None` when no complete
/// event has been buffered yet.
#[cfg(not(feature = "unified_sse"))]
#[allow(dead_code)]
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    for i in 0..buffer.len().saturating_sub(1) {
        // Bare LF-LF terminator.
        if buffer[i..].starts_with(b"\n\n") {
            return Some(i + 2);
        }
        // CRLF-CRLF terminator.
        if buffer[i..].starts_with(b"\r\n\r\n") {
            return Some(i + 4);
        }
    }
    None
}
34
35#[cfg(not(feature = "unified_sse"))]
36#[allow(dead_code)]
37fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
38 for line in event_text.lines() {
39 let line = line.trim();
40 if let Some(stripped) = line.strip_prefix("data: ") {
41 let data = stripped;
42 if data == "[DONE]" {
43 return Some(Ok(None));
44 }
45 return Some(parse_chunk_data(data));
46 }
47 }
48 None
49}
50
51#[cfg(not(feature = "unified_sse"))]
52fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
53 match serde_json::from_str::<serde_json::Value>(data) {
54 Ok(json) => {
55 let choices = json["choices"]
56 .as_array()
57 .map(|arr| {
58 arr.iter()
59 .enumerate()
60 .map(|(index, choice)| {
61 let delta = &choice["delta"];
62 crate::api::ChoiceDelta {
63 index: index as u32,
64 delta: crate::api::MessageDelta {
65 role: delta["role"].as_str().map(|r| match r {
66 "assistant" => crate::types::Role::Assistant,
67 "user" => crate::types::Role::User,
68 "system" => crate::types::Role::System,
69 _ => crate::types::Role::Assistant,
70 }),
71 content: delta["content"].as_str().map(str::to_string),
72 },
73 finish_reason: choice["finish_reason"].as_str().map(str::to_string),
74 }
75 })
76 .collect()
77 })
78 .unwrap_or_default();
79
80 Ok(Some(ChatCompletionChunk {
81 id: json["id"].as_str().unwrap_or_default().to_string(),
82 object: json["object"]
83 .as_str()
84 .unwrap_or("chat.completion.chunk")
85 .to_string(),
86 created: json["created"].as_u64().unwrap_or(0),
87 model: json["model"].as_str().unwrap_or_default().to_string(),
88 choices,
89 }))
90 }
91 Err(e) => Err(AiLibError::ProviderError(format!(
92 "JSON parse error: {}",
93 e
94 ))),
95 }
96}
97
/// Split `text` into chunks of roughly at most `max_len` bytes.
///
/// Cut points always land on UTF-8 character boundaries (backing up first;
/// if that would yield an empty chunk, scanning forward instead, so a chunk
/// may exceed `max_len` by up to one multibyte character). When possible the
/// cut is moved back to the last space inside the window so words stay
/// intact; the space broken on is consumed and appears in no chunk.
///
/// Fix: `max_len == 0` previously looped forever pushing empty strings
/// (`cut == i` meant no progress); it is now clamped to 1.
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    // Guard against a zero-sized window, which can never make progress.
    let max_len = max_len.max(1);
    let mut chunks = Vec::new();
    let bytes = text.as_bytes();
    let mut i = 0usize;
    let len = bytes.len();
    while i < len {
        let mut end = std::cmp::min(i + max_len, len);
        if end < len && !text.is_char_boundary(end) {
            // Back up to the nearest char boundary inside the window.
            while end > i && !text.is_char_boundary(end) {
                end -= 1;
            }
            if end == i {
                // Backing up consumed the whole window (a single char wider
                // than max_len): scan forward to the next boundary instead.
                end = std::cmp::min(i + max_len, len);
                while end < len && !text.is_char_boundary(end) {
                    end += 1;
                }
                if end == i {
                    end = len;
                }
            }
        }

        // Prefer breaking at the last space inside the window (unless that
        // would produce an empty chunk, or this is the final chunk).
        let mut cut = end;
        if end < len {
            if let Some(pos) = text[i..end].rfind(' ') {
                cut = i + pos;
            }
        }
        if cut == i {
            cut = end;
        }

        let chunk = text[i..cut].to_string();
        chunks.push(chunk);
        i = cut;
        // Swallow the space we broke on so it starts no chunk.
        if i < len && bytes[i] == b' ' {
            i += 1;
        }
    }
    chunks
}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Chunking a multibyte string must not panic, must respect the window
    /// (with bounded overshoot), and must reassemble losslessly.
    #[test]
    fn split_text_handles_multibyte_without_panic() {
        // Note: this fixture contains no ASCII spaces, so no bytes are
        // dropped at cut points and concatenation must equal the input.
        let s = "这是一个包含中文字符的长文本,用于测试边界处理。🚀✨";
        let chunks = split_text_into_chunks(s, 8);
        assert!(!chunks.is_empty());
        for c in &chunks {
            // A chunk may overshoot the window by one multibyte character.
            assert!(c.len() <= 8 + 8);
        }
        // Fixed: the previous check (`… || !combined.is_empty()`) was a
        // tautology; assert exact reconstruction instead.
        assert_eq!(chunks.concat(), s);
    }
}
172use futures::StreamExt;
173use tokio::sync::mpsc;
174use tokio_stream::wrappers::UnboundedReceiverStream;
175
/// Adapter for Cohere's HTTP API, implementing `ChatApi` below.
pub struct CohereAdapter {
    // HTTP transport used for all requests (post_json/post_stream/get_json).
    #[allow(dead_code)] transport: DynHttpTransportRef,
    // Bearer token sent in the Authorization header.
    api_key: String,
    // API root, e.g. "https://api.cohere.ai"; overridable via COHERE_BASE_URL.
    base_url: String,
    // Metrics sink; defaults to NoopMetrics in the constructors.
    metrics: Arc<dyn Metrics>,
}
183
184impl CohereAdapter {
185 #[allow(dead_code)]
186 fn build_default_timeout_secs() -> u64 {
187 std::env::var("AI_HTTP_TIMEOUT_SECS")
188 .ok()
189 .and_then(|s| s.parse::<u64>().ok())
190 .unwrap_or(30)
191 }
192
193 fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
194 #[cfg(feature = "unified_transport")]
195 {
196 let timeout = Duration::from_secs(Self::build_default_timeout_secs());
197 let client = crate::transport::client_factory::build_shared_client()
198 .map_err(|e| AiLibError::NetworkError(format!("Failed to build http client: {}", e)))?;
199 let t = HttpTransport::with_reqwest_client(client, timeout);
200 return Ok(t.boxed());
201 }
202 #[cfg(not(feature = "unified_transport"))]
203 {
204 let t = HttpTransport::new();
205 return Ok(t.boxed());
206 }
207 }
208
209 pub fn new() -> Result<Self, AiLibError> {
211 let api_key = std::env::var("COHERE_API_KEY").map_err(|_| {
212 AiLibError::AuthenticationError(
213 "COHERE_API_KEY environment variable not set".to_string(),
214 )
215 })?;
216 let base_url = std::env::var("COHERE_BASE_URL")
217 .unwrap_or_else(|_| "https://api.cohere.ai".to_string());
218 Ok(Self {
219 transport: Self::build_default_transport()?,
220 api_key,
221 base_url,
222 metrics: Arc::new(NoopMetrics::new()),
223 })
224 }
225
226 pub fn new_with_overrides(
228 api_key: String,
229 base_url: Option<String>,
230 ) -> Result<Self, AiLibError> {
231 let resolved_base = base_url.unwrap_or_else(|| {
232 std::env::var("COHERE_BASE_URL").unwrap_or_else(|_| "https://api.cohere.ai".to_string())
233 });
234 Ok(Self {
235 transport: Self::build_default_transport()?,
236 api_key,
237 base_url: resolved_base,
238 metrics: Arc::new(NoopMetrics::new()),
239 })
240 }
241
242 pub fn with_transport(transport: HttpTransport, api_key: String, base_url: String) -> Self {
244 Self {
245 transport: transport.boxed(),
246 api_key,
247 base_url,
248 metrics: Arc::new(NoopMetrics::new()),
249 }
250 }
251
252 pub fn with_transport_ref(
254 transport: DynHttpTransportRef,
255 api_key: String,
256 base_url: String,
257 ) -> Self {
258 Self {
259 transport,
260 api_key,
261 base_url,
262 metrics: Arc::new(NoopMetrics::new()),
263 }
264 }
265
266 pub fn with_transport_ref_and_metrics(
268 transport: DynHttpTransportRef,
269 api_key: String,
270 base_url: String,
271 metrics: Arc<dyn Metrics>,
272 ) -> Self {
273 Self {
274 transport,
275 api_key,
276 base_url,
277 metrics,
278 }
279 }
280
281 fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
282 let msgs: Vec<serde_json::Value> = request
284 .messages
285 .iter()
286 .map(|msg| {
287 serde_json::json!({
288 "role": match msg.role {
289 Role::System => "system",
290 Role::User => "user",
291 Role::Assistant => "assistant",
292 },
293 "content": msg.content.as_text()
294 })
295 })
296 .collect();
297
298 let mut chat_body = serde_json::json!({
299 "model": request.model,
300 "messages": msgs,
301 });
302
303 if let Some(temp) = request.temperature {
305 chat_body["temperature"] =
306 serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
307 }
308 if let Some(max_tokens) = request.max_tokens {
309 chat_body["max_tokens"] =
310 serde_json::Value::Number(serde_json::Number::from(max_tokens));
311 }
312
313 chat_body
314 }
315
316 fn parse_response(
317 &self,
318 response: serde_json::Value,
319 ) -> Result<ChatCompletionResponse, AiLibError> {
320 let content = if let Some(c) = response.get("choices") {
322 c[0]["message"]["content"]
324 .as_str()
325 .map(|s| s.to_string())
326 .or_else(|| c[0]["text"].as_str().map(|s| s.to_string()))
327 } else if let Some(msg) = response.get("message") {
328 msg.get("content").and_then(|content_array| {
330 content_array
331 .as_array()
332 .and_then(|arr| arr.first())
333 .and_then(|content_obj| {
334 content_obj
335 .get("text")
336 .and_then(|t| t.as_str())
337 .map(|text| text.to_string())
338 })
339 })
340 } else if let Some(gens) = response.get("generations") {
341 gens[0]["text"].as_str().map(|s| s.to_string())
343 } else {
344 None
345 };
346
347 let content = content.unwrap_or_default();
348
349 let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
350 if let Some(fc_val) = response.get("function_call") {
351 if let Ok(fc) =
352 serde_json::from_value::<crate::types::function_call::FunctionCall>(fc_val.clone())
353 {
354 function_call = Some(fc);
355 } else if let Some(name) = fc_val
356 .get("name")
357 .and_then(|v| v.as_str())
358 .map(|s| s.to_string())
359 {
360 let args = fc_val.get("arguments").and_then(|a| {
361 if a.is_string() {
362 serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
363 } else {
364 Some(a.clone())
365 }
366 });
367 function_call = Some(crate::types::function_call::FunctionCall {
368 name,
369 arguments: args,
370 });
371 }
372 } else if let Some(tool_calls) = response.get("tool_calls").and_then(|v| v.as_array()) {
373 if let Some(first) = tool_calls.first() {
374 if let Some(func) = first.get("function").or_else(|| first.get("function_call")) {
375 if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
376 let mut args_opt = func.get("arguments").cloned();
377 if let Some(args_val) = &args_opt {
378 if args_val.is_string() {
379 if let Some(s) = args_val.as_str() {
380 if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s) {
381 args_opt = Some(parsed);
382 }
383 }
384 }
385 }
386 function_call = Some(crate::types::function_call::FunctionCall {
387 name: name.to_string(),
388 arguments: args_opt,
389 });
390 }
391 }
392 }
393 }
394
395 let choice = Choice {
396 index: 0,
397 message: Message {
398 role: Role::Assistant,
399 content: crate::types::common::Content::Text(content.clone()),
400 function_call,
401 },
402 finish_reason: None,
403 };
404
405 let usage = Usage {
406 prompt_tokens: 0,
407 completion_tokens: 0,
408 total_tokens: 0,
409 };
410
411 Ok(ChatCompletionResponse {
412 id: response["id"].as_str().unwrap_or_default().to_string(),
413 object: response["object"].as_str().unwrap_or_default().to_string(),
414 created: response["created"].as_u64().unwrap_or(0),
415 model: response["model"].as_str().unwrap_or_default().to_string(),
416 choices: vec![choice],
417 usage,
418 usage_status: UsageStatus::Unsupported, })
420 }
421}
422
423#[async_trait::async_trait]
424impl ChatApi for CohereAdapter {
425 async fn chat_completion(
426 &self,
427 request: ChatCompletionRequest,
428 ) -> Result<ChatCompletionResponse, AiLibError> {
429 self.metrics.incr_counter("cohere.requests", 1).await;
430 let timer = self.metrics.start_timer("cohere.request_duration_ms").await;
431
432 let _body = self.convert_request(&request);
433
434 let url_generate = format!("{}/v1/generate", self.base_url);
436
437 let mut headers = HashMap::new();
438 headers.insert(
439 "Authorization".to_string(),
440 format!("Bearer {}", self.api_key),
441 );
442 headers.insert("Content-Type".to_string(), "application/json".to_string());
443 headers.insert("Accept".to_string(), "application/json".to_string());
444
445 let prompt = request
447 .messages
448 .iter()
449 .map(|msg| match msg.role {
450 Role::System => format!("System: {}", msg.content.as_text()),
451 Role::User => format!("Human: {}", msg.content.as_text()),
452 Role::Assistant => format!("Assistant: {}", msg.content.as_text()),
453 })
454 .collect::<Vec<_>>()
455 .join("\n");
456
457 let mut generate_body = serde_json::json!({
458 "model": request.model,
459 "prompt": prompt,
460 });
461
462 if let Some(temp) = request.temperature {
463 generate_body["temperature"] =
464 serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
465 }
466 if let Some(max_tokens) = request.max_tokens {
467 generate_body["max_tokens"] =
468 serde_json::Value::Number(serde_json::Number::from(max_tokens));
469 }
470
471 let response_json = self
473 .transport
474 .post_json(&url_generate, Some(headers), generate_body)
475 .await?;
476
477 if let Some(t) = timer {
478 t.stop();
479 }
480
481 self.parse_response(response_json)
482 }
483
484 async fn chat_completion_stream(
485 &self,
486 _request: ChatCompletionRequest,
487 ) -> Result<
488 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
489 AiLibError,
490 > {
491 let mut stream_request = self.convert_request(&_request);
493 stream_request["stream"] = serde_json::Value::Bool(true);
494
495 let url = format!("{}/v1/chat", self.base_url);
496
497 let mut headers = HashMap::new();
498 headers.insert(
499 "Authorization".to_string(),
500 format!("Bearer {}", self.api_key),
501 );
502 if let Ok(byte_stream) = self
504 .transport
505 .post_stream(&url, Some(headers.clone()), stream_request.clone())
506 .await
507 {
508 let (tx, rx) = mpsc::unbounded_channel();
509 tokio::spawn(async move {
510 let mut buffer = Vec::new();
511 futures::pin_mut!(byte_stream);
512 while let Some(item) = byte_stream.next().await {
513 match item {
514 Ok(bytes) => {
515 buffer.extend_from_slice(&bytes);
516 #[cfg(feature = "unified_sse")]
517 {
518 while let Some(boundary) =
519 crate::sse::parser::find_event_boundary(&buffer)
520 {
521 let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
522 if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
523 if let Some(parsed) =
524 crate::sse::parser::parse_sse_event(event_text)
525 {
526 match parsed {
527 Ok(Some(chunk)) => {
528 if tx.send(Ok(chunk)).is_err() {
529 return;
530 }
531 }
532 Ok(None) => return,
533 Err(e) => {
534 let _ = tx.send(Err(e));
535 return;
536 }
537 }
538 }
539 }
540 }
541 }
542 #[cfg(not(feature = "unified_sse"))]
543 {
544 while let Some(boundary) = find_event_boundary(&buffer) {
545 let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
546 if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
547 if let Some(parsed) = parse_sse_event(event_text) {
548 match parsed {
549 Ok(Some(chunk)) => {
550 if tx.send(Ok(chunk)).is_err() {
551 return;
552 }
553 }
554 Ok(None) => return,
555 Err(e) => {
556 let _ = tx.send(Err(e));
557 return;
558 }
559 }
560 }
561 }
562 }
563 }
564 }
565 Err(e) => {
566 let _ = tx.send(Err(AiLibError::ProviderError(format!(
567 "Stream error: {}",
568 e
569 ))));
570 break;
571 }
572 }
573 }
574 });
575 let stream = UnboundedReceiverStream::new(rx);
576 return Ok(Box::new(Box::pin(stream)));
577 }
578
579 let finished = self.chat_completion(_request.clone()).await?;
581 let text = finished
582 .choices
583 .first()
584 .map(|c| c.message.content.as_text())
585 .unwrap_or_default();
586
587 let (tx, rx) = mpsc::unbounded_channel();
588
589 tokio::spawn(async move {
590 let chunks = split_text_into_chunks(&text, 80);
591 for chunk in chunks {
592 let delta = crate::api::ChoiceDelta {
594 index: 0,
595 delta: crate::api::MessageDelta {
596 role: Some(crate::types::Role::Assistant),
597 content: Some(chunk.clone()),
598 },
599 finish_reason: None,
600 };
601 let chunk_obj = ChatCompletionChunk {
602 id: "simulated".to_string(),
603 object: "chat.completion.chunk".to_string(),
604 created: 0,
605 model: finished.model.clone(),
606 choices: vec![delta],
607 };
608
609 if tx.send(Ok(chunk_obj)).is_err() {
610 return;
611 }
612 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
613 }
614 });
615
616 let stream = UnboundedReceiverStream::new(rx);
617 Ok(Box::new(Box::pin(stream)))
618 }
619
620 async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
621 let url = format!("{}/v1/models", self.base_url);
623 let mut headers = HashMap::new();
624 headers.insert(
625 "Authorization".to_string(),
626 format!("Bearer {}", self.api_key),
627 );
628
629 let response = self.transport.get_json(&url, Some(headers)).await?;
630
631 Ok(response["models"]
632 .as_array()
633 .unwrap_or(&vec![])
634 .iter()
635 .filter_map(|m| {
636 m["id"]
637 .as_str()
638 .map(|s| s.to_string())
639 .or_else(|| m["name"].as_str().map(|s| s.to_string()))
640 })
641 .collect())
642 }
643
644 async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
645 Ok(ModelInfo {
646 id: model_id.to_string(),
647 object: "model".to_string(),
648 created: 0,
649 owned_by: "cohere".to_string(),
650 permission: vec![ModelPermission {
651 id: "default".to_string(),
652 object: "model_permission".to_string(),
653 created: 0,
654 allow_create_engine: false,
655 allow_sampling: true,
656 allow_logprobs: false,
657 allow_search_indices: false,
658 allow_view: true,
659 allow_fine_tuning: false,
660 organization: "*".to_string(),
661 group: None,
662 is_blocking: false,
663 }],
664 })
665 }
666}