1use std::collections::HashMap;
12
13use crate::{
14 agent::AgentBuilder,
15 completion::{self, CompletionError},
16 embeddings::{self, EmbeddingError, EmbeddingsBuilder},
17 extractor::ExtractorBuilder,
18 json_utils, message, Embed, OneOrMany,
19};
20
21use schemars::JsonSchema;
22use serde::{Deserialize, Serialize};
23use serde_json::json;
24
/// Base URL that [`Client::new`] passes to [`Client::from_url`].
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
29
/// Handle used to issue requests against the Cohere HTTP API.
///
/// The model constructors on this type store a clone of the client.
#[derive(Clone)]
pub struct Client {
    /// URL prefix that `post` combines with each request path.
    base_url: String,
    /// `reqwest` client; `from_url` builds it with a default `Authorization` header.
    http_client: reqwest::Client,
}
35
36impl Client {
37 pub fn new(api_key: &str) -> Self {
38 Self::from_url(api_key, COHERE_API_BASE_URL)
39 }
40
41 pub fn from_url(api_key: &str, base_url: &str) -> Self {
42 Self {
43 base_url: base_url.to_string(),
44 http_client: reqwest::Client::builder()
45 .default_headers({
46 let mut headers = reqwest::header::HeaderMap::new();
47 headers.insert(
48 "Authorization",
49 format!("Bearer {}", api_key)
50 .parse()
51 .expect("Bearer token should parse"),
52 );
53 headers
54 })
55 .build()
56 .expect("Cohere reqwest client should build"),
57 }
58 }
59
60 pub fn from_env() -> Self {
63 let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
64 Self::new(&api_key)
65 }
66
67 pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
68 let url = format!("{}/{}", self.base_url, path).replace("//", "/");
69 self.http_client.post(url)
70 }
71
72 pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
75 let ndims = match model {
76 EMBED_ENGLISH_V3 | EMBED_MULTILINGUAL_V3 | EMBED_ENGLISH_LIGHT_V2 => 1024,
77 EMBED_ENGLISH_LIGHT_V3 | EMBED_MULTILINGUAL_LIGHT_V3 => 384,
78 EMBED_ENGLISH_V2 => 4096,
79 EMBED_MULTILINGUAL_V2 => 768,
80 _ => 0,
81 };
82 EmbeddingModel::new(self.clone(), model, input_type, ndims)
83 }
84
85 pub fn embedding_model_with_ndims(
87 &self,
88 model: &str,
89 input_type: &str,
90 ndims: usize,
91 ) -> EmbeddingModel {
92 EmbeddingModel::new(self.clone(), model, input_type, ndims)
93 }
94
95 pub fn embeddings<D: Embed>(
96 &self,
97 model: &str,
98 input_type: &str,
99 ) -> EmbeddingsBuilder<EmbeddingModel, D> {
100 EmbeddingsBuilder::new(self.embedding_model(model, input_type))
101 }
102
103 pub fn completion_model(&self, model: &str) -> CompletionModel {
104 CompletionModel::new(self.clone(), model)
105 }
106
107 pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
108 AgentBuilder::new(self.completion_model(model))
109 }
110
111 pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
112 &self,
113 model: &str,
114 ) -> ExtractorBuilder<T, CompletionModel> {
115 ExtractorBuilder::new(self.completion_model(model))
116 }
117}
118
/// Error payload: a JSON object carrying a `message` string.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    /// Text that callers wrap in their `ProviderError` variants.
    message: String,
}
123
/// Body of a response that is either the expected payload `T` or an error object.
///
/// The enum is `untagged`, so serde attempts the variants in declaration
/// order: `Ok(T)` first, then `Err`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
    /// The body matched the expected payload type.
    Ok(T),
    /// The body matched the error shape instead.
    Err(ApiErrorResponse),
}
130
/// `embed-english-v3.0`; [`Client::embedding_model`] maps it to 1024 dimensions.
pub const EMBED_ENGLISH_V3: &str = "embed-english-v3.0";
/// `embed-english-light-v3.0`; [`Client::embedding_model`] maps it to 384 dimensions.
pub const EMBED_ENGLISH_LIGHT_V3: &str = "embed-english-light-v3.0";
/// `embed-multilingual-v3.0`; [`Client::embedding_model`] maps it to 1024 dimensions.
pub const EMBED_MULTILINGUAL_V3: &str = "embed-multilingual-v3.0";
/// `embed-multilingual-light-v3.0`; [`Client::embedding_model`] maps it to 384 dimensions.
pub const EMBED_MULTILINGUAL_LIGHT_V3: &str = "embed-multilingual-light-v3.0";
/// `embed-english-v2.0`; [`Client::embedding_model`] maps it to 4096 dimensions.
pub const EMBED_ENGLISH_V2: &str = "embed-english-v2.0";
/// `embed-english-light-v2.0`; [`Client::embedding_model`] maps it to 1024 dimensions.
pub const EMBED_ENGLISH_LIGHT_V2: &str = "embed-english-light-v2.0";
/// `embed-multilingual-v2.0`; [`Client::embedding_model`] maps it to 768 dimensions.
pub const EMBED_MULTILINGUAL_V2: &str = "embed-multilingual-v2.0";
148
/// Success payload deserialized from the `/v1/embed` response body.
#[derive(Deserialize)]
pub struct EmbeddingResponse {
    /// Optional; `None` when the field is missing.
    #[serde(default)]
    pub response_type: Option<String>,
    /// Identifier string carried by the response.
    pub id: String,
    /// Embedding vectors; `embed_texts` requires one per submitted document.
    pub embeddings: Vec<Vec<f64>>,
    /// Presumably the submitted texts echoed back; confirm against the API reference.
    pub texts: Vec<String>,
    /// Version and billing metadata; `None` when the field is missing.
    #[serde(default)]
    pub meta: Option<Meta>,
}
159
/// Metadata object attached to an [`EmbeddingResponse`].
#[derive(Deserialize)]
pub struct Meta {
    /// API version information reported by the provider.
    pub api_version: ApiVersion,
    /// Usage counters; `embed_texts` logs them through `tracing`.
    pub billed_units: BilledUnits,
    /// Warning strings; empty when the field is missing.
    #[serde(default)]
    pub warnings: Vec<String>,
}
167
/// API version descriptor found in [`Meta`].
#[derive(Deserialize)]
pub struct ApiVersion {
    /// Version string as reported by the provider.
    pub version: String,
    /// Deprecation flag; `None` when the field is missing.
    #[serde(default)]
    pub is_deprecated: Option<bool>,
    /// Experimental flag; `None` when the field is missing.
    #[serde(default)]
    pub is_experimental: Option<bool>,
}
176
/// Usage counters found in [`Meta`]. Every counter defaults to `0` when its
/// field is missing from the payload.
#[derive(Deserialize, Debug)]
pub struct BilledUnits {
    /// Count reported under `input_tokens`.
    #[serde(default)]
    pub input_tokens: u32,
    /// Count reported under `output_tokens`.
    #[serde(default)]
    pub output_tokens: u32,
    /// Count reported under `search_units`.
    #[serde(default)]
    pub search_units: u32,
    /// Count reported under `classifications`.
    #[serde(default)]
    pub classifications: u32,
}
188
189impl std::fmt::Display for BilledUnits {
190 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
191 write!(
192 f,
193 "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
194 self.input_tokens, self.output_tokens, self.search_units, self.classifications
195 )
196 }
197}
198
/// Embedding model handle bound to a [`Client`].
#[derive(Clone)]
pub struct EmbeddingModel {
    /// Client used to send the embedding requests.
    client: Client,
    /// Model name sent as the `model` field of the request body.
    pub model: String,
    /// Value sent as the `input_type` field of the request body.
    pub input_type: String,
    /// Dimensionality reported by `ndims()`.
    ndims: usize,
}
206
207impl embeddings::EmbeddingModel for EmbeddingModel {
208 const MAX_DOCUMENTS: usize = 96;
209
210 fn ndims(&self) -> usize {
211 self.ndims
212 }
213
214 #[cfg_attr(feature = "worker", worker::send)]
215 async fn embed_texts(
216 &self,
217 documents: impl IntoIterator<Item = String>,
218 ) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
219 let documents = documents.into_iter().collect::<Vec<_>>();
220
221 let response = self
222 .client
223 .post("/v1/embed")
224 .json(&json!({
225 "model": self.model,
226 "texts": documents,
227 "input_type": self.input_type,
228 }))
229 .send()
230 .await?;
231
232 if response.status().is_success() {
233 match response.json::<ApiResponse<EmbeddingResponse>>().await? {
234 ApiResponse::Ok(response) => {
235 match response.meta {
236 Some(meta) => tracing::info!(target: "rig",
237 "Cohere embeddings billed units: {}",
238 meta.billed_units,
239 ),
240 None => tracing::info!(target: "rig",
241 "Cohere embeddings billed units: n/a",
242 ),
243 };
244
245 if response.embeddings.len() != documents.len() {
246 return Err(EmbeddingError::DocumentError(
247 format!(
248 "Expected {} embeddings, got {}",
249 documents.len(),
250 response.embeddings.len()
251 )
252 .into(),
253 ));
254 }
255
256 Ok(response
257 .embeddings
258 .into_iter()
259 .zip(documents.into_iter())
260 .map(|(embedding, document)| embeddings::Embedding {
261 document,
262 vec: embedding,
263 })
264 .collect())
265 }
266 ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
267 }
268 } else {
269 Err(EmbeddingError::ProviderError(response.text().await?))
270 }
271 }
272}
273
274impl EmbeddingModel {
275 pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self {
276 Self {
277 client,
278 model: model.to_string(),
279 input_type: input_type.to_string(),
280 ndims,
281 }
282 }
283}
284
/// `command-r-plus` completion model name.
pub const COMMAND_R_PLUS: &str = "command-r-plus";
/// `command-r` completion model name.
pub const COMMAND_R: &str = "command-r";
/// `command` completion model name.
pub const COMMAND: &str = "command";
/// `command-nightly` completion model name.
pub const COMMAND_NIGHTLY: &str = "command-nightly";
/// `command-light` completion model name.
pub const COMMAND_LIGHT: &str = "command-light";
/// `command-light-nightly` completion model name.
pub const COMMAND_LIGHT_NIGHTLY: &str = "command-light-nightly";
300
/// Success payload deserialized from the `/v1/chat` response body.
///
/// Fields marked `#[serde(default)]` fall back to an empty collection or
/// `None` when they are missing from the payload.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Generated text; used as the assistant content when `tool_calls` is empty.
    pub text: String,
    /// Identifier string carried by the response.
    pub generation_id: String,
    /// Citation entries, if any.
    #[serde(default)]
    pub citations: Vec<Citation>,
    /// Document entries, if any.
    #[serde(default)]
    pub documents: Vec<Document>,
    /// Optional flag reported by the provider.
    #[serde(default)]
    pub is_search_required: Option<bool>,
    /// Search query entries, if any.
    #[serde(default)]
    pub search_queries: Vec<SearchQuery>,
    /// Search result entries, if any.
    #[serde(default)]
    pub search_results: Vec<SearchResult>,
    /// Finish reason string as reported by the provider.
    pub finish_reason: String,
    /// Tool calls; when non-empty they replace `text` as the assistant content.
    #[serde(default)]
    pub tool_calls: Vec<ToolCall>,
    /// Chat history entries, if any.
    #[serde(default)]
    pub chat_history: Vec<ChatHistory>,
}
321
322impl From<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
323 fn from(response: CompletionResponse) -> Self {
324 let CompletionResponse {
325 text, tool_calls, ..
326 } = &response;
327
328 let model_response = if !tool_calls.is_empty() {
329 tool_calls
330 .iter()
331 .map(|tool_call| {
332 completion::AssistantContent::tool_call(
333 tool_call.name.clone(),
334 tool_call.name.clone(),
335 tool_call.parameters.clone(),
336 )
337 })
338 .collect::<Vec<_>>()
339 } else {
340 vec![completion::AssistantContent::text(text.clone())]
341 };
342
343 completion::CompletionResponse {
344 choice: OneOrMany::many(model_response).expect("There is atleast one content"),
345 raw_response: response,
346 }
347 }
348}
349
/// Entry of [`CompletionResponse::citations`].
#[derive(Debug, Deserialize)]
pub struct Citation {
    /// Start position; presumably an offset into the response text (TODO confirm).
    pub start: u32,
    /// End position; presumably an offset into the response text (TODO confirm).
    pub end: u32,
    /// Text associated with the citation.
    pub text: String,
    /// Identifiers of the documents associated with the citation.
    pub document_ids: Vec<String>,
}
357
/// Entry of [`CompletionResponse::documents`].
#[derive(Debug, Deserialize)]
pub struct Document {
    /// Document identifier.
    pub id: String,
    /// Every other key of the JSON object, collected through `#[serde(flatten)]`.
    #[serde(flatten)]
    pub additional_prop: HashMap<String, serde_json::Value>,
}
364
/// Entry of [`CompletionResponse::search_queries`]; also embedded in [`SearchResult`].
#[derive(Debug, Deserialize)]
pub struct SearchQuery {
    /// Query text.
    pub text: String,
    /// Identifier string attached to the query.
    pub generation_id: String,
}
370
/// Entry of [`CompletionResponse::search_results`].
#[derive(Debug, Deserialize)]
pub struct SearchResult {
    /// Query this result belongs to.
    pub search_query: SearchQuery,
    /// Connector reference reported with the result.
    pub connector: Connector,
    /// Identifiers of the documents associated with the result.
    pub document_ids: Vec<String>,
    /// Optional error text; `None` when the field is missing.
    #[serde(default)]
    pub error_message: Option<String>,
    /// Defaults to `false` when the field is missing.
    #[serde(default)]
    pub continue_on_failure: bool,
}
381
/// Connector reference embedded in a [`SearchResult`].
#[derive(Debug, Deserialize)]
pub struct Connector {
    /// Connector identifier.
    pub id: String,
}
386
/// A tool invocation: the tool's name plus its arguments as arbitrary JSON.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolCall {
    /// Name of the tool being invoked.
    pub name: String,
    /// Arguments for the invocation, kept as raw JSON.
    pub parameters: serde_json::Value,
}
392
/// Entry of [`CompletionResponse::chat_history`].
#[derive(Debug, Deserialize)]
pub struct ChatHistory {
    /// Role string as reported by the provider.
    pub role: String,
    /// Message text.
    pub message: String,
}
398
/// Description of a single tool argument inside a [`ToolDefinition`].
#[derive(Debug, Deserialize, Serialize)]
pub struct Parameter {
    /// Human-readable description of the argument.
    pub description: String,
    /// Type name; the conversion from `completion::ToolDefinition` emits one of
    /// `string`, `number`, `integer`, `boolean`, `array` or `object`.
    pub r#type: String,
    /// Whether the argument name appears in the schema's `required` list.
    pub required: bool,
}
405
/// Tool description in the shape sent under the `tools` key of a chat request.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    /// Tool name.
    pub name: String,
    /// Tool description.
    pub description: String,
    /// Argument descriptions keyed by argument name.
    pub parameter_definitions: HashMap<String, Parameter>,
}
412
413impl From<completion::ToolDefinition> for ToolDefinition {
414 fn from(tool: completion::ToolDefinition) -> Self {
415 fn convert_type(r#type: &serde_json::Value) -> String {
416 fn convert_type_str(r#type: &str) -> String {
417 match r#type {
418 "string" => "string".to_owned(),
419 "number" => "number".to_owned(),
420 "integer" => "integer".to_owned(),
421 "boolean" => "boolean".to_owned(),
422 "array" => "array".to_owned(),
423 "object" => "object".to_owned(),
424 _ => "string".to_owned(),
425 }
426 }
427 match r#type {
428 serde_json::Value::String(r#type) => convert_type_str(r#type.as_str()),
429 serde_json::Value::Array(types) => convert_type_str(
430 types
431 .iter()
432 .find(|t| t.as_str() != Some("null"))
433 .and_then(|t| t.as_str())
434 .unwrap_or("string"),
435 ),
436 _ => "string".to_owned(),
437 }
438 }
439
440 let maybe_required = tool
441 .parameters
442 .get("required")
443 .and_then(|v| v.as_array())
444 .map(|required| {
445 required
446 .iter()
447 .filter_map(|v| v.as_str())
448 .collect::<Vec<_>>()
449 })
450 .unwrap_or_default();
451
452 Self {
453 name: tool.name,
454 description: tool.description,
455 parameter_definitions: tool
456 .parameters
457 .get("properties")
458 .expect("Tool properties should exist")
459 .as_object()
460 .expect("Tool properties should be an object")
461 .iter()
462 .map(|(argname, argdef)| {
463 (
464 argname.clone(),
465 Parameter {
466 description: argdef
467 .get("description")
468 .expect("Argument description should exist")
469 .as_str()
470 .expect("Argument description should be a string")
471 .to_string(),
472 r#type: convert_type(
473 argdef.get("type").expect("Argument type should exist"),
474 ),
475 required: maybe_required.contains(&argname.as_str()),
476 },
477 )
478 })
479 .collect::<HashMap<_, _>>(),
480 }
481 }
482}
483
/// Chat history entry in the provider's wire format.
///
/// Serialized with an internal `role` tag whose value is the upper-cased
/// variant name (`USER`, `CHATBOT`, `TOOL`, `SYSTEM`).
#[derive(Deserialize, Serialize)]
#[serde(tag = "role", rename_all = "UPPERCASE")]
pub enum Message {
    /// Entry tagged `USER`.
    User {
        /// Message text.
        message: String,
        /// Tool calls attached to the entry.
        tool_calls: Vec<ToolCall>,
    },

    /// Entry tagged `CHATBOT`.
    Chatbot {
        /// Message text.
        message: String,
        /// Tool calls attached to the entry.
        tool_calls: Vec<ToolCall>,
    },

    /// Entry tagged `TOOL`, carrying tool outputs.
    Tool {
        /// Results of previously requested tool calls.
        tool_results: Vec<ToolResult>,
    },

    /// Entry tagged `SYSTEM`.
    // NOTE(review): this variant names its text field `content` while the
    // other variants use `message`; confirm this matches the API.
    System {
        /// Message text.
        content: String,
        /// Tool calls attached to the entry.
        tool_calls: Vec<ToolCall>,
    },
}
507
/// Outcome of a tool invocation, carried by [`Message::Tool`].
#[derive(Deserialize, Serialize)]
pub struct ToolResult {
    /// The call that the outputs belong to.
    pub call: ToolCall,
    /// Outputs of the call as raw JSON values.
    pub outputs: Vec<serde_json::Value>,
}
513
514impl TryFrom<message::Message> for Vec<Message> {
515 type Error = message::MessageError;
516
517 fn try_from(message: message::Message) -> Result<Self, Self::Error> {
518 match message {
519 message::Message::User { content } => content
520 .into_iter()
521 .map(|content| {
522 Ok(Message::User {
523 message: match content {
524 message::UserContent::Text(message::Text { text }) => text,
525 _ => {
526 return Err(message::MessageError::ConversionError(
527 "Only text content is supported by Cohere".to_owned(),
528 ))
529 }
530 },
531 tool_calls: vec![],
532 })
533 })
534 .collect::<Result<Vec<_>, _>>(),
535 _ => Err(message::MessageError::ConversionError(
536 "Only user messages are supported by Cohere".to_owned(),
537 )),
538 }
539 }
540}
541
/// Completion model handle bound to a [`Client`].
#[derive(Clone)]
pub struct CompletionModel {
    /// Client used to send the chat requests.
    client: Client,
    /// Model name sent as the `model` field of the request body.
    pub model: String,
}
547
548impl CompletionModel {
549 pub fn new(client: Client, model: &str) -> Self {
550 Self {
551 client,
552 model: model.to_string(),
553 }
554 }
555}
556
557impl completion::CompletionModel for CompletionModel {
558 type Response = CompletionResponse;
559
560 #[cfg_attr(feature = "worker", worker::send)]
561 async fn completion(
562 &self,
563 completion_request: completion::CompletionRequest,
564 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
565 let chat_history = completion_request
566 .chat_history
567 .into_iter()
568 .map(Vec::<Message>::try_from)
569 .collect::<Result<Vec<Vec<_>>, _>>()?
570 .into_iter()
571 .flatten()
572 .collect::<Vec<_>>();
573
574 let message = match completion_request.prompt {
575 message::Message::User { content } => Ok(content
576 .into_iter()
577 .map(|content| match content {
578 message::UserContent::Text(message::Text { text }) => Ok(text),
579 _ => Err(CompletionError::RequestError(
580 "Only text content is supported by Cohere".into(),
581 )),
582 })
583 .collect::<Result<Vec<_>, _>>()?
584 .join("\n")),
585
586 _ => Err(CompletionError::RequestError(
587 "Only user messages are supported by Cohere".into(),
588 )),
589 }?;
590
591 let request = json!({
592 "model": self.model,
593 "preamble": completion_request.preamble,
594 "message": message,
595 "documents": completion_request.documents,
596 "chat_history": chat_history,
597 "temperature": completion_request.temperature,
598 "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
599 });
600
601 let response = self
602 .client
603 .post("/v1/chat")
604 .json(
605 &if let Some(ref params) = completion_request.additional_params {
606 json_utils::merge(request.clone(), params.clone())
607 } else {
608 request.clone()
609 },
610 )
611 .send()
612 .await?;
613
614 if response.status().is_success() {
615 match response.json::<ApiResponse<CompletionResponse>>().await? {
616 ApiResponse::Ok(completion) => Ok(completion.into()),
617 ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
618 }
619 } else {
620 Err(CompletionError::ProviderError(response.text().await?))
621 }
622 }
623}