use std::collections::HashMap;
use crate::{
agent::AgentBuilder,
completion::{self, CompletionError},
embeddings::{self, EmbeddingError, EmbeddingsBuilder},
extractor::ExtractorBuilder,
json_utils, Embed,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use serde_json::json;
/// Default Cohere API endpoint; `Client::new` passes this to `Client::from_url`.
const COHERE_API_BASE_URL: &str = "https://api.cohere.ai";
/// Handle used to send authenticated HTTP requests to a Cohere-compatible API.
///
/// The embedding and completion models created from a client each store their
/// own clone of it.
#[derive(Clone)]
pub struct Client {
// Base URL that `post` prepends to every request path.
base_url: String,
// HTTP client built in `from_url` with an `Authorization: Bearer <key>` default header.
http_client: reqwest::Client,
}
impl Client {
pub fn new(api_key: &str) -> Self {
Self::from_url(api_key, COHERE_API_BASE_URL)
}
pub fn from_url(api_key: &str, base_url: &str) -> Self {
Self {
base_url: base_url.to_string(),
http_client: reqwest::Client::builder()
.default_headers({
let mut headers = reqwest::header::HeaderMap::new();
headers.insert(
"Authorization",
format!("Bearer {}", api_key)
.parse()
.expect("Bearer token should parse"),
);
headers
})
.build()
.expect("Cohere reqwest client should build"),
}
}
pub fn from_env() -> Self {
let api_key = std::env::var("COHERE_API_KEY").expect("COHERE_API_KEY not set");
Self::new(&api_key)
}
pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
let url = format!("{}/{}", self.base_url, path).replace("//", "/");
self.http_client.post(url)
}
pub fn embedding_model(&self, model: &str, input_type: &str) -> EmbeddingModel {
let ndims = match model {
EMBED_ENGLISH_V3 | EMBED_MULTILINGUAL_V3 | EMBED_ENGLISH_LIGHT_V2 => 1024,
EMBED_ENGLISH_LIGHT_V3 | EMBED_MULTILINGUAL_LIGHT_V3 => 384,
EMBED_ENGLISH_V2 => 4096,
EMBED_MULTILINGUAL_V2 => 768,
_ => 0,
};
EmbeddingModel::new(self.clone(), model, input_type, ndims)
}
pub fn embedding_model_with_ndims(
&self,
model: &str,
input_type: &str,
ndims: usize,
) -> EmbeddingModel {
EmbeddingModel::new(self.clone(), model, input_type, ndims)
}
pub fn embeddings<D: Embed>(
&self,
model: &str,
input_type: &str,
) -> EmbeddingsBuilder<EmbeddingModel, D> {
EmbeddingsBuilder::new(self.embedding_model(model, input_type))
}
pub fn completion_model(&self, model: &str) -> CompletionModel {
CompletionModel::new(self.clone(), model)
}
pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
AgentBuilder::new(self.completion_model(model))
}
pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
&self,
model: &str,
) -> ExtractorBuilder<T, CompletionModel> {
ExtractorBuilder::new(self.completion_model(model))
}
}
/// Error payload shape: a JSON object carrying a `message` string.
///
/// Used as the `Err` arm of `ApiResponse`; the message is forwarded to callers
/// inside a `ProviderError`.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
// Human-readable error text taken from the response body.
message: String,
}
/// Response body that is either the expected payload `T` or an error object.
///
/// Because the enum is `untagged`, serde tries the variants in declaration
/// order: the body is first deserialized as `T`, and only if that fails as
/// `ApiErrorResponse`.
#[derive(Debug, Deserialize)]
#[serde(untagged)]
enum ApiResponse<T> {
// Body matched the expected payload type.
Ok(T),
// Body matched the `{ "message": ... }` error shape instead.
Err(ApiErrorResponse),
}
// Embedding model identifiers. The dimensions quoted below are the values
// `Client::embedding_model` associates with each name.

/// `embed-english-v3.0` embedding model (1024 dimensions).
pub const EMBED_ENGLISH_V3: &str = "embed-english-v3.0";
/// `embed-english-light-v3.0` embedding model (384 dimensions).
pub const EMBED_ENGLISH_LIGHT_V3: &str = "embed-english-light-v3.0";
/// `embed-multilingual-v3.0` embedding model (1024 dimensions).
pub const EMBED_MULTILINGUAL_V3: &str = "embed-multilingual-v3.0";
/// `embed-multilingual-light-v3.0` embedding model (384 dimensions).
pub const EMBED_MULTILINGUAL_LIGHT_V3: &str = "embed-multilingual-light-v3.0";
/// `embed-english-v2.0` embedding model (4096 dimensions).
pub const EMBED_ENGLISH_V2: &str = "embed-english-v2.0";
/// `embed-english-light-v2.0` embedding model (1024 dimensions).
pub const EMBED_ENGLISH_LIGHT_V2: &str = "embed-english-light-v2.0";
/// `embed-multilingual-v2.0` embedding model (768 dimensions).
pub const EMBED_MULTILINGUAL_V2: &str = "embed-multilingual-v2.0";
/// Successful response body of the `/v1/embed` endpoint, as consumed by
/// `EmbeddingModel::embed_texts`.
#[derive(Deserialize)]
pub struct EmbeddingResponse {
/// Optional `response_type` field; `None` when absent from the payload.
#[serde(default)]
pub response_type: Option<String>,
/// Required `id` field of the response (presumably a request/response
/// identifier; confirm against the Cohere API reference).
pub id: String,
/// Embedding vectors. `embed_texts` requires the count to equal the number
/// of submitted texts and pairs them up by position.
pub embeddings: Vec<Vec<f64>>,
/// Required `texts` field of the response (presumably the inputs echoed
/// back; confirm against the Cohere API reference).
pub texts: Vec<String>,
/// Optional metadata; `None` when absent from the payload.
#[serde(default)]
pub meta: Option<Meta>,
}
/// Metadata object attached to an `EmbeddingResponse` under `meta`.
#[derive(Deserialize)]
pub struct Meta {
/// API version information reported by the service.
pub api_version: ApiVersion,
/// Billing counters; `embed_texts` logs these after each successful call.
pub billed_units: BilledUnits,
/// Warning strings from the service; empty when the field is absent.
#[serde(default)]
pub warnings: Vec<String>,
}
/// API version descriptor found in `Meta::api_version`.
#[derive(Deserialize)]
pub struct ApiVersion {
/// Version string reported by the service.
pub version: String,
/// Deprecation flag; `None` when the field is absent.
#[serde(default)]
pub is_deprecated: Option<bool>,
/// Experimental flag; `None` when the field is absent.
#[serde(default)]
pub is_experimental: Option<bool>,
}
/// Billing counters found in `Meta::billed_units`.
///
/// Every counter defaults to `0` when it is missing from the payload.
#[derive(Deserialize, Debug)]
pub struct BilledUnits {
/// Value of the `input_tokens` counter.
#[serde(default)]
pub input_tokens: u32,
/// Value of the `output_tokens` counter.
#[serde(default)]
pub output_tokens: u32,
/// Value of the `search_units` counter.
#[serde(default)]
pub search_units: u32,
/// Value of the `classifications` counter.
#[serde(default)]
pub classifications: u32,
}
impl std::fmt::Display for BilledUnits {
    /// Renders the four counters as four newline-separated `label: value`
    /// lines, with no trailing newline.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            input_tokens,
            output_tokens,
            search_units,
            classifications,
        } = self;

        write!(
            f,
            "Input tokens: {}\nOutput tokens: {}\nSearch units: {}\nClassifications: {}",
            input_tokens, output_tokens, search_units, classifications
        )
    }
}
/// Embedding model bound to a `Client`, a model name and an input type.
#[derive(Clone)]
pub struct EmbeddingModel {
// Client used to send the `/v1/embed` requests.
client: Client,
/// Model name sent as the `model` field of each embed request.
pub model: String,
/// Value sent as the `input_type` field of each embed request.
pub input_type: String,
// Dimension count reported by `ndims()`; `0` when the model was not recognised.
ndims: usize,
}
impl embeddings::EmbeddingModel for EmbeddingModel {
const MAX_DOCUMENTS: usize = 96;
fn ndims(&self) -> usize {
self.ndims
}
#[cfg_attr(feature = "worker", worker::send)]
async fn embed_texts(
&self,
documents: impl IntoIterator<Item = String>,
) -> Result<Vec<embeddings::Embedding>, EmbeddingError> {
let documents = documents.into_iter().collect::<Vec<_>>();
let response = self
.client
.post("/v1/embed")
.json(&json!({
"model": self.model,
"texts": documents,
"input_type": self.input_type,
}))
.send()
.await?;
if response.status().is_success() {
match response.json::<ApiResponse<EmbeddingResponse>>().await? {
ApiResponse::Ok(response) => {
match response.meta {
Some(meta) => tracing::info!(target: "rig",
"Cohere embeddings billed units: {}",
meta.billed_units,
),
None => tracing::info!(target: "rig",
"Cohere embeddings billed units: n/a",
),
};
if response.embeddings.len() != documents.len() {
return Err(EmbeddingError::DocumentError(
format!(
"Expected {} embeddings, got {}",
documents.len(),
response.embeddings.len()
)
.into(),
));
}
Ok(response
.embeddings
.into_iter()
.zip(documents.into_iter())
.map(|(embedding, document)| embeddings::Embedding {
document,
vec: embedding,
})
.collect())
}
ApiResponse::Err(error) => Err(EmbeddingError::ProviderError(error.message)),
}
} else {
Err(EmbeddingError::ProviderError(response.text().await?))
}
}
}
impl EmbeddingModel {
    /// Builds an embedding model from its parts, copying `model` and
    /// `input_type` into owned strings.
    pub fn new(client: Client, model: &str, input_type: &str, ndims: usize) -> Self {
        let model = String::from(model);
        let input_type = String::from(input_type);
        Self {
            client,
            model,
            input_type,
            ndims,
        }
    }
}
// Chat/completion model identifiers, passed verbatim as the `model` field of
// `/v1/chat` requests.

/// `command-r-plus` completion model.
pub const COMMAND_R_PLUS: &str = "command-r-plus";
/// `command-r` completion model.
pub const COMMAND_R: &str = "command-r";
/// `command` completion model.
pub const COMMAND: &str = "command";
/// `command-nightly` completion model.
pub const COMMAND_NIGHTLY: &str = "command-nightly";
/// `command-light` completion model.
pub const COMMAND_LIGHT: &str = "command-light";
/// `command-light-nightly` completion model.
pub const COMMAND_LIGHT_NIGHTLY: &str = "command-light-nightly";
/// Successful response body of the `/v1/chat` endpoint.
///
/// Only `text`, `generation_id` and `finish_reason` are required; every other
/// field falls back to its default (empty list / `None`) when absent.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
/// Generated text; used as the message when the response has no tool calls.
pub text: String,
/// Required `generation_id` field of the response.
pub generation_id: String,
/// Entries of the `citations` array.
#[serde(default)]
pub citations: Vec<Citation>,
/// Entries of the `documents` array.
#[serde(default)]
pub documents: Vec<Document>,
/// Value of `is_search_required`, if present.
#[serde(default)]
pub is_search_required: Option<bool>,
/// Entries of the `search_queries` array.
#[serde(default)]
pub search_queries: Vec<SearchQuery>,
/// Entries of the `search_results` array.
#[serde(default)]
pub search_results: Vec<SearchResult>,
/// Required `finish_reason` field of the response.
pub finish_reason: String,
/// Tool calls requested by the model; only the first one is surfaced when
/// converting into the generic completion response.
#[serde(default)]
pub tool_calls: Vec<ToolCall>,
/// Entries of the `chat_history` array.
#[serde(default)]
pub chat_history: Vec<ChatHistory>,
}
impl From<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
    /// Wraps a Cohere chat response in the generic completion response.
    ///
    /// If the response contains tool calls, the first one becomes the choice
    /// (with an empty second `ToolCall` argument); otherwise the choice is the
    /// response text. The full response is kept as `raw_response`.
    fn from(response: CompletionResponse) -> Self {
        let choice = match response.tool_calls.first() {
            Some(call) => completion::ModelChoice::ToolCall(
                call.name.clone(),
                String::new(),
                call.parameters.clone(),
            ),
            None => completion::ModelChoice::Message(response.text.clone()),
        };

        completion::CompletionResponse {
            choice,
            raw_response: response,
        }
    }
}
/// One entry of the `citations` array of a chat response.
#[derive(Debug, Deserialize)]
pub struct Citation {
/// `start` field (presumably the start offset of the cited span in the
/// response text; confirm against the Cohere API reference).
pub start: u32,
/// `end` field (presumably the end offset of the cited span; confirm).
pub end: u32,
/// `text` field of the citation.
pub text: String,
/// Identifiers listed under `document_ids`.
pub document_ids: Vec<String>,
}
/// One entry of the `documents` array of a chat response.
#[derive(Debug, Deserialize)]
pub struct Document {
/// `id` field of the document.
pub id: String,
/// Every other key of the document object, captured verbatim through
/// `#[serde(flatten)]`.
#[serde(flatten)]
pub additional_prop: HashMap<String, serde_json::Value>,
}
/// Search query object; appears in the `search_queries` array of a chat
/// response and nested inside each `SearchResult`.
#[derive(Debug, Deserialize)]
pub struct SearchQuery {
/// `text` field of the query.
pub text: String,
/// `generation_id` field of the query.
pub generation_id: String,
}
/// One entry of the `search_results` array of a chat response.
#[derive(Debug, Deserialize)]
pub struct SearchResult {
/// The `search_query` object of this result.
pub search_query: SearchQuery,
/// The `connector` object of this result.
pub connector: Connector,
/// Identifiers listed under `document_ids`.
pub document_ids: Vec<String>,
/// `error_message` field; `None` when absent.
#[serde(default)]
pub error_message: Option<String>,
/// `continue_on_failure` field; `false` when absent.
#[serde(default)]
pub continue_on_failure: bool,
}
/// Connector object referenced by a `SearchResult`.
#[derive(Debug, Deserialize)]
pub struct Connector {
/// `id` field of the connector.
pub id: String,
}
/// One entry of the `tool_calls` array of a chat response.
#[derive(Debug, Deserialize)]
pub struct ToolCall {
/// Name of the tool the model wants to invoke.
pub name: String,
/// Arguments for the call, kept as raw JSON.
pub parameters: serde_json::Value,
}
/// One entry of the `chat_history` array of a chat response.
#[derive(Debug, Deserialize)]
pub struct ChatHistory {
/// `role` field of the entry.
pub role: String,
/// `message` field of the entry.
pub message: String,
}
/// Description of a single tool parameter; the value type of
/// `ToolDefinition::parameter_definitions`.
#[derive(Debug, Deserialize, Serialize)]
pub struct Parameter {
/// Human-readable description of the parameter.
pub description: String,
/// Type name of the parameter, serialized under the key `type`.
pub r#type: String,
/// Whether the parameter was listed as required in the source schema.
pub required: bool,
}
/// Tool description in the shape sent under `tools` in `/v1/chat` requests.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
/// Name of the tool.
pub name: String,
/// Human-readable description of the tool.
pub description: String,
/// Parameter descriptions keyed by parameter name.
pub parameter_definitions: HashMap<String, Parameter>,
}
impl From<completion::ToolDefinition> for ToolDefinition {
fn from(tool: completion::ToolDefinition) -> Self {
fn convert_type(r#type: &serde_json::Value) -> String {
fn convert_type_str(r#type: &str) -> String {
match r#type {
"string" => "string".to_owned(),
"number" => "number".to_owned(),
"integer" => "integer".to_owned(),
"boolean" => "boolean".to_owned(),
"array" => "array".to_owned(),
"object" => "object".to_owned(),
_ => "string".to_owned(),
}
}
match r#type {
serde_json::Value::String(r#type) => convert_type_str(r#type.as_str()),
serde_json::Value::Array(types) => convert_type_str(
types
.iter()
.find(|t| t.as_str() != Some("null"))
.and_then(|t| t.as_str())
.unwrap_or("string"),
),
_ => "string".to_owned(),
}
}
let maybe_required = tool
.parameters
.get("required")
.and_then(|v| v.as_array())
.map(|required| {
required
.iter()
.filter_map(|v| v.as_str())
.collect::<Vec<_>>()
})
.unwrap_or_default();
Self {
name: tool.name,
description: tool.description,
parameter_definitions: tool
.parameters
.get("properties")
.expect("Tool properties should exist")
.as_object()
.expect("Tool properties should be an object")
.iter()
.map(|(argname, argdef)| {
(
argname.clone(),
Parameter {
description: argdef
.get("description")
.expect("Argument description should exist")
.as_str()
.expect("Argument description should be a string")
.to_string(),
r#type: convert_type(
argdef.get("type").expect("Argument type should exist"),
),
required: maybe_required.contains(&argname.as_str()),
},
)
})
.collect::<HashMap<_, _>>(),
}
}
}
/// Chat history entry in the shape sent under `chat_history` in `/v1/chat`
/// requests.
#[derive(Deserialize, Serialize)]
pub struct Message {
/// Speaker role; the conversion from `completion::Message` produces
/// `"SYSTEM"`, `"USER"` or `"CHATBOT"`.
pub role: String,
/// Text content of the entry.
pub message: String,
}
impl From<completion::Message> for Message {
    /// Converts a generic chat message into Cohere's representation.
    ///
    /// Role mapping: `"system"` → `"SYSTEM"`, `"assistant"` → `"CHATBOT"`;
    /// `"user"` and every other role → `"USER"`.
    fn from(message: completion::Message) -> Self {
        let cohere_role = match message.role.as_str() {
            "system" => "SYSTEM",
            "assistant" => "CHATBOT",
            // "user" and any unrecognised role are both sent as USER.
            _ => "USER",
        };

        Self {
            role: cohere_role.to_owned(),
            message: message.content,
        }
    }
}
/// Completion (chat) model bound to a `Client` and a model name.
#[derive(Clone)]
pub struct CompletionModel {
// Client used to send the `/v1/chat` requests.
client: Client,
/// Model name sent as the `model` field of each chat request.
pub model: String,
}
impl CompletionModel {
    /// Builds a completion model from a client and a model name, copying the
    /// name into an owned string.
    pub fn new(client: Client, model: &str) -> Self {
        let model = String::from(model);
        Self { client, model }
    }
}
impl completion::CompletionModel for CompletionModel {
    type Response = CompletionResponse;

    /// Sends `completion_request` to `POST /v1/chat` and converts the answer
    /// into the generic completion response.
    ///
    /// The request body is built from the model name, preamble, prompt,
    /// documents, chat history, temperature and tools. When
    /// `additional_params` is set it is merged into that body with
    /// `json_utils::merge`.
    ///
    /// # Errors
    /// * `ProviderError` when the service answers with a non-success status or
    ///   with an error payload.
    /// * Transport / decoding failures are propagated via `?`.
    #[cfg_attr(feature = "worker", worker::send)]
    async fn completion(
        &self,
        completion_request: completion::CompletionRequest,
    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
        let base_body = json!({
            "model": self.model,
            "preamble": completion_request.preamble,
            "message": completion_request.prompt,
            "documents": completion_request.documents,
            "chat_history": completion_request.chat_history.into_iter().map(Message::from).collect::<Vec<_>>(),
            "temperature": completion_request.temperature,
            "tools": completion_request.tools.into_iter().map(ToolDefinition::from).collect::<Vec<_>>(),
        });

        // Caller-supplied extras are merged into the base body when present.
        let body = match completion_request.additional_params {
            Some(extra) => json_utils::merge(base_body, extra),
            None => base_body,
        };

        let http_response = self.client.post("/v1/chat").json(&body).send().await?;

        // Non-2xx: surface the raw body as the provider error.
        if !http_response.status().is_success() {
            return Err(CompletionError::ProviderError(http_response.text().await?));
        }

        match http_response
            .json::<ApiResponse<CompletionResponse>>()
            .await?
        {
            ApiResponse::Ok(parsed) => Ok(parsed.into()),
            ApiResponse::Err(error) => Err(CompletionError::ProviderError(error.message)),
        }
    }
}