use crate::adapter::adapters::support::get_api_key;
use crate::adapter::cohere::CohereStreamer;
use crate::adapter::{Adapter, AdapterKind, ServiceType, WebRequestData};
use crate::chat::{
ChatOptionsSet, ChatRequest, ChatResponse, ChatRole, ChatStream, ChatStreamResponse, MessageContent, Usage,
};
use crate::resolver::{AuthData, Endpoint};
use crate::webc::{WebResponse, WebStream};
use crate::{Error, Headers, Result};
use crate::{ModelIden, ServiceTarget};
use reqwest::RequestBuilder;
use serde_json::{Value, json};
use value_ext::JsonValueExt;
/// Unit type carrying the Cohere implementation of the `Adapter` trait,
/// plus the private helpers that translate chat requests and usage data.
pub struct CohereAdapter;
/// Static list of model names returned by `all_model_names`.
const MODELS: &[&str] = &[
"command-r-plus",
"command-r",
"command",
"command-nightly",
"command-light",
"command-light-nightly",
];
impl CohereAdapter {
/// Name of the environment variable used as the default source of the API key.
/// The `Adapter` impl re-exports it through `DEFAULT_API_KEY_ENV_NAME`.
pub const API_KEY_DEFAULT_ENV_NAME: &str = "COHERE_API_KEY";
}
impl Adapter for CohereAdapter {
/// Default API-key environment variable for this adapter; always `Some`,
/// wrapping `Self::API_KEY_DEFAULT_ENV_NAME`.
const DEFAULT_API_KEY_ENV_NAME: Option<&'static str> = Some(Self::API_KEY_DEFAULT_ENV_NAME);
/// Returns the endpoint used when the caller does not supply one.
///
/// The base URL carries the `v1/` version segment and a trailing slash, which
/// `get_service_url` relies on when appending the service path.
fn default_endpoint() -> Endpoint {
    Endpoint::from_static("https://api.cohere.com/v1/")
}
/// Returns the default authentication source for this adapter.
///
/// When `DEFAULT_API_KEY_ENV_NAME` names an environment variable, the key is
/// taken from it via `AuthData::from_env`; otherwise `AuthData::None` is returned.
fn default_auth() -> AuthData {
    let Some(env_name) = Self::DEFAULT_API_KEY_ENV_NAME else {
        return AuthData::None;
    };
    AuthData::from_env(env_name)
}
/// Returns the names from the static `MODELS` table as owned strings.
///
/// The adapter kind is ignored and no network call is made here.
async fn all_model_names(_kind: AdapterKind) -> Result<Vec<String>> {
    let names: Vec<String> = MODELS.iter().copied().map(String::from).collect();
    Ok(names)
}
/// Builds the full URL for the requested service from the endpoint's base URL.
///
/// - `Chat` / `ChatStream`: `{base_url}chat`.
/// - `Embed`: the base URL with one trailing `v1/` segment (if present)
///   replaced by `v2/`, followed by `embed`.
///
/// The model is not used to pick the URL. The base URL is expected to end with
/// a `/`, since the service path is appended without adding a separator.
fn get_service_url(_model: &ModelIden, service_type: ServiceType, endpoint: Endpoint) -> Result<String> {
    let base_url = endpoint.base_url();
    let url = match service_type {
        ServiceType::Chat | ServiceType::ChatStream => format!("{base_url}chat"),
        ServiceType::Embed => {
            // `strip_suffix` removes at most one `v1/`; `trim_end_matches` would
            // strip every trailing repetition (e.g. `.../v1/v1/`), rewriting more
            // of a custom base URL than the version segment being swapped.
            let base_without_version = base_url.strip_suffix("v1/").unwrap_or(base_url);
            format!("{base_without_version}v2/embed")
        }
    };
    Ok(url)
}
/// Assembles the URL, headers and JSON payload for a Cohere chat request.
///
/// The API key is resolved from the target's auth data and sent as a
/// `Bearer` token. The chat request is split into preamble, current message
/// and history by `into_cohere_request_parts`. Optional settings
/// (temperature, stop sequences, max tokens, top-p) are only added to the
/// payload when they are set.
///
/// # Errors
/// Propagates failures from API-key resolution, URL construction, request
/// splitting and payload insertion.
fn to_web_request_data(
    target: ServiceTarget,
    service_type: ServiceType,
    chat_req: ChatRequest,
    options_set: ChatOptionsSet<'_, '_>,
) -> Result<WebRequestData> {
    let ServiceTarget { endpoint, auth, model } = target;

    // -- Auth, URL and headers
    let api_key = get_api_key(auth, &model)?;
    let url = Self::get_service_url(&model, service_type, endpoint)?;
    let headers = Headers::from(("Authorization".to_string(), format!("Bearer {api_key}")));

    // -- Split the chat request into the pieces the payload needs
    let parts = Self::into_cohere_request_parts(model.clone(), chat_req)?;

    // -- Base payload; only the name part of the model identifier is sent
    let (_, model_name) = model.model_name.namespace_and_name();
    let is_streaming = matches!(service_type, ServiceType::ChatStream);
    let mut payload = json!({
        "model": model_name.to_string(),
        "message": parts.message,
        "stream": is_streaming
    });

    // -- Conversation context
    if !parts.chat_history.is_empty() {
        payload.x_insert("chat_history", parts.chat_history)?;
    }
    if let Some(preamble) = parts.preamble {
        payload.x_insert("preamble", preamble)?;
    }

    // -- Optional generation settings
    if let Some(temperature) = options_set.temperature() {
        payload.x_insert("temperature", temperature)?;
    }
    let stop_sequences = options_set.stop_sequences();
    if !stop_sequences.is_empty() {
        payload.x_insert("stop_sequences", stop_sequences)?;
    }
    if let Some(max_tokens) = options_set.max_tokens() {
        payload.x_insert("max_tokens", max_tokens)?;
    }
    if let Some(top_p) = options_set.top_p() {
        // The top-p option is sent under the key "p".
        payload.x_insert("p", top_p)?;
    }

    Ok(WebRequestData { url, headers, payload })
}
fn to_chat_response(
model_iden: ModelIden,
web_response: WebResponse,
_options_set: ChatOptionsSet<'_, '_>,
) -> Result<ChatResponse> {
let WebResponse { mut body, .. } = web_response;
let provider_model_name = None;
let provider_model_iden = model_iden.from_optional_name(provider_model_name);
let usage = body.x_take("/meta/tokens").map(Self::into_usage).unwrap_or_default();
let Some(mut last_chat_history_item) = body.x_take::<Vec<Value>>("chat_history")?.pop() else {
return Err(Error::NoChatResponse { model_iden });
};
let content: MessageContent = last_chat_history_item
.x_take::<Option<String>>("message")?
.map(MessageContent::from)
.unwrap_or_default();
Ok(ChatResponse {
content,
reasoning_content: None,
model_iden,
provider_model_iden,
usage,
captured_raw_body: None, })
}
/// Wraps the pending request in a `ChatStreamResponse`.
///
/// The web stream is split on `"\n"`, fed through `CohereStreamer`, and then
/// exposed as a `ChatStream`.
fn to_chat_stream(
    model_iden: ModelIden,
    reqwest_builder: RequestBuilder,
    options_set: ChatOptionsSet<'_, '_>,
) -> Result<ChatStreamResponse> {
    let streamer = CohereStreamer::new(
        WebStream::new_with_delimiter(reqwest_builder, "\n"),
        model_iden.clone(),
        options_set,
    );

    Ok(ChatStreamResponse {
        model_iden,
        stream: ChatStream::from_inter_stream(streamer),
    })
}
/// Builds the web request data for an embedding call.
///
/// Delegates unchanged to `super::embed::to_embed_request_data`.
fn to_embed_request_data(
service_target: crate::ServiceTarget,
embed_req: crate::embed::EmbedRequest,
options_set: crate::embed::EmbedOptionsSet<'_, '_>,
) -> Result<crate::adapter::WebRequestData> {
super::embed::to_embed_request_data(service_target, embed_req, options_set)
}
/// Converts the web response of an embedding call into an `EmbedResponse`.
///
/// Delegates unchanged to `super::embed::to_embed_response`.
fn to_embed_response(
model_iden: crate::ModelIden,
web_response: crate::webc::WebResponse,
options_set: crate::embed::EmbedOptionsSet<'_, '_>,
) -> Result<crate::embed::EmbedResponse> {
super::embed::to_embed_response(model_iden, web_response, options_set)
}
}
impl CohereAdapter {
/// Converts a Cohere token-count JSON value into a `Usage`.
///
/// `input_tokens` maps to the prompt count and `output_tokens` to the
/// completion count; a value that cannot be extracted becomes `None`.
/// The total is the sum of both counts (a missing one counts as 0), or
/// `None` when neither is available.
pub(super) fn into_usage(mut usage_value: Value) -> Usage {
    let input: Option<i32> = usage_value.x_take("input_tokens").ok();
    let output: Option<i32> = usage_value.x_take("output_tokens").ok();

    let total_tokens = match (input, output) {
        (None, None) => None,
        (input, output) => Some(input.unwrap_or(0) + output.unwrap_or(0)),
    };

    // Kept from the original: silences deprecation warnings on this construction.
    #[allow(deprecated)]
    Usage {
        prompt_tokens: input,
        prompt_tokens_details: None,
        completion_tokens: output,
        completion_tokens_details: None,
        total_tokens,
    }
}
fn into_cohere_request_parts(
model_iden: ModelIden, mut chat_req: ChatRequest,
) -> Result<CohereChatRequestParts> {
let mut chat_history: Vec<Value> = Vec::new();
let mut systems: Vec<String> = Vec::new();
if let Some(system) = chat_req.system {
systems.push(system);
}
let last_chat_msg = chat_req.messages.pop().ok_or_else(|| Error::ChatReqHasNoMessages {
model_iden: model_iden.clone(),
})?;
if !matches!(last_chat_msg.role, ChatRole::User) {
return Err(Error::LastChatMessageIsNotUser {
model_iden,
actual_role: last_chat_msg.role,
});
}
let Some(message) = last_chat_msg.content.into_joined_texts() else {
return Err(Error::MessageContentTypeNotSupported {
model_iden,
cause: "Only MessageContent::Text supported for this model (for now)",
});
};
for msg in chat_req.messages {
let Some(content) = msg.content.into_joined_texts() else {
return Err(Error::MessageContentTypeNotSupported {
model_iden,
cause: "Only MessageContent::Text supported for this model (for now)",
});
};
match msg.role {
ChatRole::System => systems.push(content),
ChatRole::User => chat_history.push(json! ({"role": "USER", "content": content})),
ChatRole::Assistant => chat_history.push(json! ({"role": "CHATBOT", "content": content})),
ChatRole::Tool => {
return Err(Error::MessageRoleNotSupported {
model_iden,
role: ChatRole::Tool,
});
}
}
}
let preamble = if !systems.is_empty() {
Some(systems.join("\n"))
} else {
None
};
Ok(CohereChatRequestParts {
preamble,
message,
chat_history,
})
}
}
/// Pieces of a chat request as produced by `into_cohere_request_parts` and
/// consumed when building the request payload.
struct CohereChatRequestParts {
/// System texts joined with `"\n"`; `None` when the request has none.
preamble: Option<String>,
/// Joined text of the last (user) message.
message: String,
/// Earlier user/assistant messages as JSON objects with `role` and `content`.
chat_history: Vec<Value>,
}