use serde::{Deserialize, Serialize};
/// Errors produced by [`Client`] requests and the [`Engine`] endpoints.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// `reqwest` failed to send the request, or to read or decode the response body.
    #[error("transport error: {0}")]
    Transport(#[from] reqwest::Error),
    /// The server replied with a non-success HTTP status.
    #[error("api error {status}: {message}")]
    Api {
        /// Numeric HTTP status code of the reply.
        status: u16,
        /// Raw response body text (empty if the body could not be read).
        message: String,
    },
    /// A configuration problem, described by the contained message.
    #[error("config error: {0}")]
    Config(String),
}
pub type Result<T> = std::result::Result<T, Error>;
/// Server root used by [`Client::from_env`] when `HANZO_BASE_URL` is not available.
pub const DEFAULT_BASE_URL: &str = "http://localhost:36900";
/// Blocking API client: a base URL, an optional bearer token and a `reqwest` HTTP client.
///
/// Obtain the model endpoints through [`Client::engine`].
#[derive(Clone)]
pub struct Client {
    /// Server root; request paths such as `/v1/models` are appended to it verbatim.
    base_url: String,
    /// Sent as `Authorization: Bearer <key>` on every request when present.
    api_key: Option<String>,
    /// Underlying blocking HTTP client that performs the requests.
    http: reqwest::blocking::Client,
}
impl Client {
pub fn new(base_url: impl Into<String>) -> Self {
Client {
base_url: base_url.into(),
api_key: None,
http: reqwest::blocking::Client::new(),
}
}
pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
self.api_key = Some(key.into());
self
}
pub fn from_env() -> Self {
let base = std::env::var("HANZO_BASE_URL").unwrap_or_else(|_| DEFAULT_BASE_URL.to_string());
let mut c = Client::new(base);
if let Ok(k) = std::env::var("HANZO_API_KEY") {
c = c.with_api_key(k);
}
c
}
pub fn engine(&self) -> Engine<'_> {
Engine { client: self }
}
fn post<B: Serialize, R: for<'de> Deserialize<'de>>(&self, path: &str, body: &B) -> Result<R> {
let mut req = self.http.post(format!("{}{}", self.base_url, path)).json(body);
if let Some(k) = &self.api_key {
req = req.bearer_auth(k);
}
self.send(req)
}
fn get<R: for<'de> Deserialize<'de>>(&self, path: &str) -> Result<R> {
let mut req = self.http.get(format!("{}{}", self.base_url, path));
if let Some(k) = &self.api_key {
req = req.bearer_auth(k);
}
self.send(req)
}
fn send<R: for<'de> Deserialize<'de>>(&self, req: reqwest::blocking::RequestBuilder) -> Result<R> {
let resp = req.send()?;
let status = resp.status();
if !status.is_success() {
return Err(Error::Api {
status: status.as_u16(),
message: resp.text().unwrap_or_default(),
});
}
Ok(resp.json()?)
}
}
/// Borrowed view of a [`Client`] that exposes the model endpoints
/// (`/v1/chat/completions`, `/v1/embeddings`, `/v1/models`).
pub struct Engine<'c> {
    /// Client used to send every request.
    client: &'c Client,
}
impl Engine<'_> {
    /// Asks `model` to continue the conversation in `messages`.
    ///
    /// Sends `POST /v1/chat/completions`.
    ///
    /// # Errors
    /// [`Error::Transport`] if the request or response decoding fails;
    /// [`Error::Api`] if the server returns a non-success status.
    pub fn chat(&self, model: &str, messages: &[Message]) -> Result<ChatResponse> {
        let body = ChatRequest { model, messages };
        self.client.post("/v1/chat/completions", &body)
    }

    /// Computes embeddings of every string in `input` using `model`.
    ///
    /// Sends `POST /v1/embeddings`.
    ///
    /// # Errors
    /// Same failure modes as [`Engine::chat`].
    pub fn embeddings(&self, model: &str, input: Vec<String>) -> Result<EmbeddingsResponse> {
        let body = EmbeddingsRequest { model, input };
        self.client.post("/v1/embeddings", &body)
    }

    /// Lists the models the server reports via `GET /v1/models`.
    ///
    /// # Errors
    /// Same failure modes as [`Engine::chat`].
    pub fn models(&self) -> Result<ModelList> {
        self.client.get::<ModelList>("/v1/models")
    }
}
/// One chat message: who said it (`role`) and what was said (`content`).
///
/// Used both in chat requests and in the choices of a [`ChatResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// Speaker role string, e.g. the value passed to [`msg`].
    pub role: String,
    // NOTE(review): a reply with `"content": null` (common for tool-call replies in
    // OpenAI-style APIs) would fail to deserialize into `String` — confirm the server
    // never sends that.
    /// Message text.
    pub content: String,
}
/// Shorthand for building a [`Message`] from a `role` and any string-like `content`.
pub fn msg(role: &str, content: impl Into<String>) -> Message {
    let role = String::from(role);
    let content: String = content.into();
    Message { role, content }
}
/// JSON body of `POST /v1/chat/completions`; borrows the model name and messages from the caller.
#[derive(Serialize)]
struct ChatRequest<'r> {
    /// Model identifier.
    model: &'r str,
    /// Conversation so far, in order.
    messages: &'r [Message],
}
/// Decoded reply of `POST /v1/chat/completions`.
///
/// Only `choices` is read; serde ignores any other fields in the reply.
#[derive(Debug, Deserialize)]
pub struct ChatResponse {
    /// Candidate completions returned by the server.
    pub choices: Vec<ChatChoice>,
}
impl ChatResponse {
    /// Returns the text of the first choice, or `""` when the server returned no choices.
    pub fn text(&self) -> &str {
        match self.choices.first() {
            Some(choice) => &choice.message.content,
            None => "",
        }
    }
}
/// One candidate completion inside a [`ChatResponse`].
#[derive(Debug, Deserialize)]
pub struct ChatChoice {
    /// The generated message.
    pub message: Message,
}
/// JSON body of `POST /v1/embeddings`.
#[derive(Serialize)]
struct EmbeddingsRequest<'r> {
    /// Model identifier, borrowed from the caller.
    model: &'r str,
    /// Texts to embed.
    input: Vec<String>,
}
/// Decoded reply of `POST /v1/embeddings`.
///
/// Only `data` is read; serde ignores any other fields in the reply.
#[derive(Debug, Deserialize)]
pub struct EmbeddingsResponse {
    /// Returned embeddings (presumably one per input, in input order — confirm with the server).
    pub data: Vec<Embedding>,
}
/// A single embedding entry from an [`EmbeddingsResponse`].
#[derive(Debug, Deserialize)]
pub struct Embedding {
    /// The embedding vector, decoded as `f32` values.
    pub embedding: Vec<f32>,
}
/// Decoded reply of `GET /v1/models`.
#[derive(Debug, Deserialize)]
pub struct ModelList {
    /// Models reported by the server.
    pub data: Vec<Model>,
}
/// One entry of a [`ModelList`]; only the identifier is decoded.
#[derive(Debug, Deserialize)]
pub struct Model {
    /// Model identifier, usable as the `model` argument of the [`Engine`] calls.
    pub id: String,
}