use crate::{
agent::AgentBuilder,
embeddings::{self},
extractor::ExtractorBuilder,
Embed,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use super::{completion::CompletionModel, embedding::EmbeddingModel, M2_BERT_80M_8K_RETRIEVAL};
/// Base URL of the public Together AI API, used by `Client::new`.
const TOGETHER_AI_BASE_URL: &str = "https://api.together.xyz";

/// HTTP client for the Together AI API.
///
/// Bundles a `reqwest` client preconfigured with default request headers together with
/// the base URL that request paths are resolved against.
#[derive(Clone)]
pub struct Client {
    /// Underlying HTTP client carrying the default headers.
    http_client: reqwest::Client,
    /// URL prefix prepended to every request path.
    base_url: String,
}
impl Client {
    /// Creates a client that targets the default Together AI endpoint and
    /// authenticates with `api_key`.
    ///
    /// # Panics
    ///
    /// Panics if `api_key` contains bytes that are not allowed in an HTTP header value,
    /// or if the underlying `reqwest` client cannot be built.
    pub fn new(api_key: &str) -> Self {
        Self::from_url(api_key, TOGETHER_AI_BASE_URL)
    }

    /// Creates a client that sends requests to `base_url` instead of the default endpoint.
    ///
    /// Every request carries `Content-Type: application/json` and
    /// `Authorization: Bearer <api_key>` as default headers.
    ///
    /// # Panics
    ///
    /// Same conditions as [`Client::new`].
    fn from_url(api_key: &str, base_url: &str) -> Self {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(
            reqwest::header::CONTENT_TYPE,
            reqwest::header::HeaderValue::from_static("application/json"),
        );

        let mut auth = reqwest::header::HeaderValue::from_str(&format!("Bearer {}", api_key))
            .expect("Bearer token should parse");
        // Mark the credential sensitive so `Debug` formatting of the header value
        // redacts it instead of printing the API key.
        auth.set_sensitive(true);
        headers.insert(reqwest::header::AUTHORIZATION, auth);

        Self {
            base_url: base_url.to_string(),
            http_client: reqwest::Client::builder()
                .default_headers(headers)
                .build()
                .expect("Together AI reqwest client should build"),
        }
    }

    /// Creates a client using the API key from the `TOGETHER_API_KEY` environment variable.
    ///
    /// # Panics
    ///
    /// Panics if `TOGETHER_API_KEY` is unset or not valid Unicode, and under the same
    /// conditions as [`Client::new`].
    pub fn from_env() -> Self {
        let api_key = std::env::var("TOGETHER_API_KEY").expect("TOGETHER_API_KEY not set");
        Self::new(&api_key)
    }

    /// Starts a `POST` request to `path`, resolved relative to the client's base URL.
    ///
    /// Exactly one `/` separates the base URL and `path`, whether or not the base URL
    /// ends with a slash or `path` starts with one.
    pub fn post(&self, path: &str) -> reqwest::RequestBuilder {
        // Trim only at the join point: a blanket `replace("//", "/")` would also collapse
        // the `//` that follows the URL scheme (`https://` -> `https:/`).
        let url = format!(
            "{}/{}",
            self.base_url.trim_end_matches('/'),
            path.trim_start_matches('/')
        );
        tracing::debug!("POST {}", url);
        self.http_client.post(url)
    }

    /// Creates an embedding model handle for `model`.
    ///
    /// The embedding dimension is filled in for known models. For any other model it is
    /// set to `0`; use [`Client::embedding_model_with_ndims`] to supply the correct value.
    pub fn embedding_model(&self, model: &str) -> EmbeddingModel {
        let ndims = match model {
            // M2-BERT-80M-8K-Retrieval produces 768-dimensional vectors; the "8K" in its
            // name is the 8192-token context window, not the embedding size.
            M2_BERT_80M_8K_RETRIEVAL => 768,
            _ => 0,
        };
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    /// Creates an embedding model handle for `model` with an explicit embedding dimension.
    pub fn embedding_model_with_ndims(&self, model: &str, ndims: usize) -> EmbeddingModel {
        EmbeddingModel::new(self.clone(), model, ndims)
    }

    /// Creates an [`embeddings::EmbeddingsBuilder`] backed by the embedding model returned
    /// from [`Client::embedding_model`] for `model`.
    pub fn embeddings<D: Embed>(
        &self,
        model: &str,
    ) -> embeddings::EmbeddingsBuilder<EmbeddingModel, D> {
        embeddings::EmbeddingsBuilder::new(self.embedding_model(model))
    }

    /// Creates a completion model handle for `model`.
    pub fn completion_model(&self, model: &str) -> CompletionModel {
        CompletionModel::new(self.clone(), model)
    }

    /// Creates an [`AgentBuilder`] backed by the completion model `model`.
    pub fn agent(&self, model: &str) -> AgentBuilder<CompletionModel> {
        AgentBuilder::new(self.completion_model(model))
    }

    /// Creates an [`ExtractorBuilder`] for type `T` backed by the completion model `model`.
    pub fn extractor<T: JsonSchema + for<'a> Deserialize<'a> + Serialize + Send + Sync>(
        &self,
        model: &str,
    ) -> ExtractorBuilder<T, CompletionModel> {
        ExtractorBuilder::new(self.completion_model(model))
    }
}
/// Response body types for the Together AI API.
pub mod together_ai_api_types {
    use serde::Deserialize;

    /// Error payload deserialized from an API response body.
    #[derive(Debug, Deserialize)]
    pub struct ApiErrorResponse {
        /// The `error` field of the payload.
        pub error: String,
        /// The `code` field of the payload.
        pub code: String,
    }

    impl ApiErrorResponse {
        /// Renders the error as ``Code `<code>`: <error>``.
        pub fn message(&self) -> String {
            let Self { error, code } = self;
            format!("Code `{code}`: {error}")
        }
    }

    /// A response body that holds either a success payload `T` or an [`ApiErrorResponse`].
    ///
    /// Deserialized as an untagged enum: serde attempts the variants in declaration
    /// order, so `Ok(T)` is tried first and `Error` is the fallback.
    #[derive(Debug, Deserialize)]
    #[serde(untagged)]
    pub enum ApiResponse<T> {
        /// Successful payload.
        Ok(T),
        /// Error payload.
        Error(ApiErrorResponse),
    }
}