use crate::api::{ChatApi, ChatCompletionChunk, ModelInfo, ModelPermission};
use crate::metrics::{Metrics, NoopMetrics};
use crate::transport::{DynHttpTransportRef, HttpTransport};
use crate::types::{
    AiLibError, ChatCompletionRequest, ChatCompletionResponse, Choice, Message, Role, Usage,
};
use futures::stream::Stream;
use futures::StreamExt;
use std::collections::HashMap;
use std::sync::Arc;
#[cfg(feature = "unified_transport")]
use std::time::Duration;
use tokio::sync::mpsc;
use tokio_stream::wrappers::UnboundedReceiverStream;

#[cfg(not(feature = "unified_sse"))]
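/// Scan `buffer` for an SSE event terminator (a blank line: `\n\n` or `\r\n\r\n`) and
/// return the byte offset just past it, or `None` if no complete event is buffered yet.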
fn find_event_boundary(buffer: &[u8]) -> Option<usize> {
    let mut i = 0;
    while i < buffer.len().saturating_sub(1) {
        if buffer[i] == b'\n' && buffer[i + 1] == b'\n' {
            return Some(i + 2);
        }
        if i < buffer.len().saturating_sub(3)
            && buffer[i] == b'\r'
            && buffer[i + 1] == b'\n'
            && buffer[i + 2] == b'\r'
            && buffer[i + 3] == b'\n'
        {
            return Some(i + 4);
        }
        i += 1;
    }
    None
}

#[cfg(not(feature = "unified_sse"))]
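/// Parse a single SSE event block. Returns `None` if the event carries no `data:` line,
/// `Some(Ok(None))` for the `[DONE]` sentinel, and otherwise the parsed chunk or an error.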
fn parse_sse_event(event_text: &str) -> Option<Result<Option<ChatCompletionChunk>, AiLibError>> {
    for line in event_text.lines() {
        let line = line.trim();
        if let Some(data) = line.strip_prefix("data: ") {
            if data == "[DONE]" {
                return Some(Ok(None));
            }
            return Some(parse_chunk_data(data));
        }
    }
    None
}

#[cfg(not(feature = "unified_sse"))]
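/// Deserialize one `data:` payload into a `ChatCompletionChunk`, tolerating missing fields.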
fn parse_chunk_data(data: &str) -> Result<Option<ChatCompletionChunk>, AiLibError> {
    match serde_json::from_str::<serde_json::Value>(data) {
        Ok(json) => {
            let choices = json["choices"]
                .as_array()
                .map(|arr| {
                    arr.iter()
                        .enumerate()
                        .map(|(index, choice)| {
                            let delta = &choice["delta"];
                            crate::api::ChoiceDelta {
                                index: index as u32,
                                delta: crate::api::MessageDelta {
                                    role: delta["role"].as_str().map(|r| match r {
                                        "assistant" => crate::types::Role::Assistant,
                                        "user" => crate::types::Role::User,
                                        "system" => crate::types::Role::System,
                                        _ => crate::types::Role::Assistant,
                                    }),
                                    content: delta["content"].as_str().map(str::to_string),
                                },
                                finish_reason: choice["finish_reason"].as_str().map(str::to_string),
                            }
                        })
                        .collect()
                })
                .unwrap_or_default();

            Ok(Some(ChatCompletionChunk {
                id: json["id"].as_str().unwrap_or_default().to_string(),
                object: json["object"]
                    .as_str()
                    .unwrap_or("chat.completion.chunk")
                    .to_string(),
                created: json["created"].as_u64().unwrap_or(0),
                model: json["model"].as_str().unwrap_or_default().to_string(),
                choices,
            }))
        }
        Err(e) => Err(AiLibError::ProviderError(format!(
            "JSON parse error: {}",
            e
        ))),
    }
}

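/// Split `text` into chunks of roughly `max_len` bytes, preferring to break at spaces.
/// Used to simulate streaming when the provider does not return an SSE stream.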
fn split_text_into_chunks(text: &str, max_len: usize) -> Vec<String> {
    let mut chunks = Vec::new();
    let mut start = 0;
    let len = text.len();
    while start < len {
        let mut end = std::cmp::min(start + max_len, len);
        // Never cut inside a multi-byte UTF-8 character: advance to the next char boundary
        // (a chunk may therefore exceed `max_len` by up to three bytes).
        while end < len && !text.is_char_boundary(end) {
            end += 1;
        }
        let mut cut = end;
        if end < len {
            if let Some(pos) = text[start..end].rfind(' ') {
                cut = start + pos;
            }
        }
        if cut == start {
            cut = end;
        }
        chunks.push(text[start..cut].to_string());
        start = cut;
        if start < len && text.as_bytes()[start] == b' ' {
            start += 1;
        }
    }
    chunks
}

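/// Chat adapter for the Cohere API.
///
/// Speaks the provider's HTTP API through the crate's transport abstraction and reports
/// request metrics through the injected `Metrics` implementation.
///
/// Illustrative sketch (not a doctest): exact construction of `ChatCompletionRequest`
/// depends on the crate's public API. The environment-based constructor shown here reads
/// `COHERE_API_KEY` and, optionally, `COHERE_BASE_URL`.
///
/// ```ignore
/// let adapter = CohereAdapter::new()?;
/// let response = adapter.chat_completion(request).await?;
/// println!("{}", response.choices[0].message.content.as_text());
/// ```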
pub struct CohereAdapter {
    #[allow(dead_code)]
    transport: DynHttpTransportRef,
    api_key: String,
    base_url: String,
    metrics: Arc<dyn Metrics>,
}

impl CohereAdapter {
    /// Default HTTP timeout in seconds, taken from `AI_HTTP_TIMEOUT_SECS` (falls back to 30).
    fn build_default_timeout_secs() -> u64 {
        std::env::var("AI_HTTP_TIMEOUT_SECS")
            .ok()
            .and_then(|s| s.parse::<u64>().ok())
            .unwrap_or(30)
    }

    /// Construct the default transport; uses the shared reqwest client when the
    /// `unified_transport` feature is enabled.
    fn build_default_transport() -> Result<DynHttpTransportRef, AiLibError> {
        #[cfg(feature = "unified_transport")]
        {
            let timeout = Duration::from_secs(Self::build_default_timeout_secs());
            let client = crate::transport::client_factory::build_shared_client().map_err(|e| {
                AiLibError::NetworkError(format!("Failed to build http client: {}", e))
            })?;
            let t = HttpTransport::with_reqwest_client(client, timeout);
            return Ok(t.boxed());
        }
        #[cfg(not(feature = "unified_transport"))]
        {
            let t = HttpTransport::new();
            return Ok(t.boxed());
        }
    }

    /// Create an adapter from the environment: requires `COHERE_API_KEY` and honours
    /// `COHERE_BASE_URL` (defaulting to `https://api.cohere.ai`).
    pub fn new() -> Result<Self, AiLibError> {
        let api_key = std::env::var("COHERE_API_KEY").map_err(|_| {
            AiLibError::AuthenticationError(
                "COHERE_API_KEY environment variable not set".to_string(),
            )
        })?;
        let base_url = std::env::var("COHERE_BASE_URL")
            .unwrap_or_else(|_| "https://api.cohere.ai".to_string());
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Create an adapter with an explicit API key and an optional base URL override.
    pub fn new_with_overrides(
        api_key: String,
        base_url: Option<String>,
    ) -> Result<Self, AiLibError> {
        let resolved_base = base_url.unwrap_or_else(|| {
            std::env::var("COHERE_BASE_URL").unwrap_or_else(|_| "https://api.cohere.ai".to_string())
        });
        Ok(Self {
            transport: Self::build_default_transport()?,
            api_key,
            base_url: resolved_base,
            metrics: Arc::new(NoopMetrics::new()),
        })
    }

    /// Create an adapter with a caller-supplied transport (useful for tests and custom clients).
    pub fn with_transport(transport: HttpTransport, api_key: String, base_url: String) -> Self {
        Self {
            transport: transport.boxed(),
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        }
    }

    /// Like `with_transport`, but takes an already-boxed transport reference.
    pub fn with_transport_ref(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
    ) -> Self {
        Self {
            transport,
            api_key,
            base_url,
            metrics: Arc::new(NoopMetrics::new()),
        }
    }

    /// Like `with_transport_ref`, but also installs a custom metrics sink.
    pub fn with_transport_ref_and_metrics(
        transport: DynHttpTransportRef,
        api_key: String,
        base_url: String,
        metrics: Arc<dyn Metrics>,
    ) -> Self {
        Self {
            transport,
            api_key,
            base_url,
            metrics,
        }
    }

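    /// Build the `/v1/chat` request body: role-tagged messages plus optional
    /// `temperature` and `max_tokens`.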
    fn convert_request(&self, request: &ChatCompletionRequest) -> serde_json::Value {
        let msgs: Vec<serde_json::Value> = request
            .messages
            .iter()
            .map(|msg| {
                serde_json::json!({
                    "role": match msg.role {
                        Role::System => "system",
                        Role::User => "user",
                        Role::Assistant => "assistant",
                    },
                    "content": msg.content.as_text()
                })
            })
            .collect();

        let mut chat_body = serde_json::json!({
            "model": request.model,
            "messages": msgs,
        });

        if let Some(temp) = request.temperature {
            chat_body["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            chat_body["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }

        chat_body
    }

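    /// Map a Cohere response body onto the unified `ChatCompletionResponse`, accepting the
    /// `choices`, `message.content[0].text`, and `generations` shapes, and lifting any
    /// `function_call` / `tool_calls` payload into a `FunctionCall`.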
    fn parse_response(
        &self,
        response: serde_json::Value,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        let content = if let Some(c) = response.get("choices") {
            c[0]["message"]["content"]
                .as_str()
                .map(|s| s.to_string())
                .or_else(|| c[0]["text"].as_str().map(|s| s.to_string()))
        } else if let Some(msg) = response.get("message") {
            msg.get("content").and_then(|content_array| {
                content_array
                    .as_array()
                    .and_then(|arr| arr.first())
                    .and_then(|content_obj| {
                        content_obj
                            .get("text")
                            .and_then(|t| t.as_str())
                            .map(|text| text.to_string())
                    })
            })
        } else if let Some(gens) = response.get("generations") {
            gens[0]["text"].as_str().map(|s| s.to_string())
        } else {
            None
        };

        let content = content.unwrap_or_default();

        let mut function_call: Option<crate::types::function_call::FunctionCall> = None;
        if let Some(fc_val) = response.get("function_call") {
            if let Ok(fc) =
                serde_json::from_value::<crate::types::function_call::FunctionCall>(fc_val.clone())
            {
                function_call = Some(fc);
            } else if let Some(name) = fc_val
                .get("name")
                .and_then(|v| v.as_str())
                .map(|s| s.to_string())
            {
                let args = fc_val.get("arguments").and_then(|a| {
                    if a.is_string() {
                        serde_json::from_str::<serde_json::Value>(a.as_str().unwrap()).ok()
                    } else {
                        Some(a.clone())
                    }
                });
                function_call = Some(crate::types::function_call::FunctionCall {
                    name,
                    arguments: args,
                });
            }
        } else if let Some(tool_calls) = response.get("tool_calls").and_then(|v| v.as_array()) {
            if let Some(first) = tool_calls.first() {
                if let Some(func) = first.get("function").or_else(|| first.get("function_call")) {
                    if let Some(name) = func.get("name").and_then(|v| v.as_str()) {
                        let mut args_opt = func.get("arguments").cloned();
                        if let Some(args_val) = &args_opt {
                            if args_val.is_string() {
                                if let Some(s) = args_val.as_str() {
                                    if let Ok(parsed) = serde_json::from_str::<serde_json::Value>(s)
                                    {
                                        args_opt = Some(parsed);
                                    }
                                }
                            }
                        }
                        function_call = Some(crate::types::function_call::FunctionCall {
                            name: name.to_string(),
                            arguments: args_opt,
                        });
                    }
                }
            }
        }

        let choice = Choice {
            index: 0,
            message: Message {
                role: Role::Assistant,
                content: crate::types::common::Content::Text(content.clone()),
                function_call,
            },
            finish_reason: None,
        };

        let usage = Usage {
            prompt_tokens: 0,
            completion_tokens: 0,
            total_tokens: 0,
        };

        Ok(ChatCompletionResponse {
            id: response["id"].as_str().unwrap_or_default().to_string(),
            object: response["object"].as_str().unwrap_or_default().to_string(),
            created: response["created"].as_u64().unwrap_or(0),
            model: response["model"].as_str().unwrap_or_default().to_string(),
            choices: vec![choice],
            usage,
        })
    }
}

#[async_trait::async_trait]
impl ChatApi for CohereAdapter {
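    /// Non-streaming completion: builds a plain-text prompt from the chat history, posts it
    /// to `/v1/generate`, and records request count and latency metrics.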
    async fn chat_completion(
        &self,
        request: ChatCompletionRequest,
    ) -> Result<ChatCompletionResponse, AiLibError> {
        self.metrics.incr_counter("cohere.requests", 1).await;
        let timer = self.metrics.start_timer("cohere.request_duration_ms").await;

        let _body = self.convert_request(&request);

        let url_generate = format!("{}/v1/generate", self.base_url);

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        headers.insert("Content-Type".to_string(), "application/json".to_string());
        headers.insert("Accept".to_string(), "application/json".to_string());

        let prompt = request
            .messages
            .iter()
            .map(|msg| match msg.role {
                Role::System => format!("System: {}", msg.content.as_text()),
                Role::User => format!("Human: {}", msg.content.as_text()),
                Role::Assistant => format!("Assistant: {}", msg.content.as_text()),
            })
            .collect::<Vec<_>>()
            .join("\n");

        let mut generate_body = serde_json::json!({
            "model": request.model,
            "prompt": prompt,
        });

        if let Some(temp) = request.temperature {
            generate_body["temperature"] =
                serde_json::Value::Number(serde_json::Number::from_f64(temp.into()).unwrap());
        }
        if let Some(max_tokens) = request.max_tokens {
            generate_body["max_tokens"] =
                serde_json::Value::Number(serde_json::Number::from(max_tokens));
        }

        let response_json = self
            .transport
            .post_json(&url_generate, Some(headers), generate_body)
            .await?;

        if let Some(t) = timer {
            t.stop();
        }

        self.parse_response(response_json)
    }

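    /// Streaming completion: attempts a real SSE stream against `/v1/chat` with `stream: true`;
    /// if the transport cannot stream, falls back to a non-streaming call whose text is replayed
    /// as simulated chunks.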
    async fn chat_completion_stream(
        &self,
        _request: ChatCompletionRequest,
    ) -> Result<
        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
        AiLibError,
    > {
        let mut stream_request = self.convert_request(&_request);
        stream_request["stream"] = serde_json::Value::Bool(true);

        let url = format!("{}/v1/chat", self.base_url);

        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );
        if let Ok(byte_stream) = self
            .transport
            .post_stream(&url, Some(headers.clone()), stream_request.clone())
            .await
        {
            let (tx, rx) = mpsc::unbounded_channel();
            tokio::spawn(async move {
                let mut buffer = Vec::new();
                futures::pin_mut!(byte_stream);
                while let Some(item) = byte_stream.next().await {
                    match item {
                        Ok(bytes) => {
                            buffer.extend_from_slice(&bytes);
                            #[cfg(feature = "unified_sse")]
                            {
                                while let Some(boundary) =
                                    crate::sse::parser::find_event_boundary(&buffer)
                                {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) =
                                            crate::sse::parser::parse_sse_event(event_text)
                                        {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                            #[cfg(not(feature = "unified_sse"))]
                            {
                                while let Some(boundary) = find_event_boundary(&buffer) {
                                    let event_bytes = buffer.drain(..boundary).collect::<Vec<_>>();
                                    if let Ok(event_text) = std::str::from_utf8(&event_bytes) {
                                        if let Some(parsed) = parse_sse_event(event_text) {
                                            match parsed {
                                                Ok(Some(chunk)) => {
                                                    if tx.send(Ok(chunk)).is_err() {
                                                        return;
                                                    }
                                                }
                                                Ok(None) => return,
                                                Err(e) => {
                                                    let _ = tx.send(Err(e));
                                                    return;
                                                }
                                            }
                                        }
                                    }
                                }
                            }
                        }
                        Err(e) => {
                            let _ = tx.send(Err(AiLibError::ProviderError(format!(
                                "Stream error: {}",
                                e
                            ))));
                            break;
                        }
                    }
                }
            });
            let stream = UnboundedReceiverStream::new(rx);
            return Ok(Box::new(Box::pin(stream)));
        }

        // Fallback: run the request non-streaming and replay the text as simulated chunks.
        let finished = self.chat_completion(_request.clone()).await?;
        let text = finished
            .choices
            .first()
            .map(|c| c.message.content.as_text())
            .unwrap_or_default();

        let (tx, rx) = mpsc::unbounded_channel();

        tokio::spawn(async move {
            let chunks = split_text_into_chunks(&text, 80);
            for chunk in chunks {
                let delta = crate::api::ChoiceDelta {
                    index: 0,
                    delta: crate::api::MessageDelta {
                        role: Some(crate::types::Role::Assistant),
                        content: Some(chunk.clone()),
                    },
                    finish_reason: None,
                };
                let chunk_obj = ChatCompletionChunk {
                    id: "simulated".to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: 0,
                    model: finished.model.clone(),
                    choices: vec![delta],
                };

                if tx.send(Ok(chunk_obj)).is_err() {
                    return;
                }
                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
            }
        });

        let stream = UnboundedReceiverStream::new(rx);
        Ok(Box::new(Box::pin(stream)))
    }

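    /// List model identifiers via `/v1/models`, accepting either `id` or `name` fields.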
    async fn list_models(&self) -> Result<Vec<String>, AiLibError> {
        let url = format!("{}/v1/models", self.base_url);
        let mut headers = HashMap::new();
        headers.insert(
            "Authorization".to_string(),
            format!("Bearer {}", self.api_key),
        );

        let response = self.transport.get_json(&url, Some(headers)).await?;

        Ok(response["models"]
            .as_array()
            .unwrap_or(&vec![])
            .iter()
            .filter_map(|m| {
                m["id"]
                    .as_str()
                    .map(|s| s.to_string())
                    .or_else(|| m["name"].as_str().map(|s| s.to_string()))
            })
            .collect())
    }

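    /// Return static metadata for `model_id`; the provider is not queried for per-model details.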
    async fn get_model_info(&self, model_id: &str) -> Result<crate::api::ModelInfo, AiLibError> {
        Ok(ModelInfo {
            id: model_id.to_string(),
            object: "model".to_string(),
            created: 0,
            owned_by: "cohere".to_string(),
            permission: vec![ModelPermission {
                id: "default".to_string(),
                object: "model_permission".to_string(),
                created: 0,
                allow_create_engine: false,
                allow_sampling: true,
                allow_logprobs: false,
                allow_search_indices: false,
                allow_view: true,
                allow_fine_tuning: false,
                organization: "*".to_string(),
                group: None,
                is_blocking: false,
            }],
        })
    }
}
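
// A minimal sanity-check sketch for the pure helpers above (no network or API types involved).
// `split_text_into_chunks` is always compiled; the SSE helpers only exist without the
// `unified_sse` feature, so those tests are gated the same way.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn split_text_prefers_space_boundaries() {
        let chunks = split_text_into_chunks("hello world foo", 7);
        assert_eq!(chunks, vec!["hello", "world", "foo"]);
    }

    #[test]
    fn split_text_never_splits_multibyte_chars() {
        // Each 'é' is two bytes; chunking must not panic or drop characters.
        let chunks = split_text_into_chunks("ééééé", 3);
        assert_eq!(chunks.concat(), "ééééé");
    }

    #[cfg(not(feature = "unified_sse"))]
    #[test]
    fn finds_lf_and_crlf_event_boundaries() {
        assert_eq!(find_event_boundary(b"data: x\n\nrest"), Some(9));
        assert_eq!(find_event_boundary(b"data: x\r\n\r\nrest"), Some(11));
        assert_eq!(find_event_boundary(b"data: x"), None);
    }

    #[cfg(not(feature = "unified_sse"))]
    #[test]
    fn done_sentinel_terminates_the_stream() {
        assert!(matches!(parse_sse_event("data: [DONE]\n\n"), Some(Ok(None))));
    }
}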