use crate::{
client::ModelLister,
http_client::{self, HttpClientExt},
model::{Model, ModelList, ModelListingError},
providers::openai::Client,
};
use serde::Deserialize;
#[derive(Debug, Deserialize)]
struct ListModelsResponse {
data: Vec<ListModelEntry>,
}
#[derive(Debug, Deserialize)]
struct ListModelEntry {
id: String,
created: u64,
owned_by: String,
}
impl From<ListModelEntry> for Model {
fn from(value: ListModelEntry) -> Self {
let mut model = Model::from_id(value.id);
model.created_at = Some(value.created);
model.owned_by = Some(value.owned_by);
model
}
}
#[derive(Clone)]
pub struct OpenAIModelLister<H = reqwest::Client> {
client: Client<H>,
}
impl<H> ModelLister<H> for OpenAIModelLister<H>
where
H: HttpClientExt + Send + Sync + 'static,
{
type Client = Client<H>;
fn new(client: Self::Client) -> Self {
Self { client }
}
async fn list_all(&self) -> Result<ModelList, ModelListingError> {
let path = "/models";
let req = self.client.get(path)?.body(http_client::NoBody)?;
let response = self.client.send::<_, Vec<u8>>(req).await?;
if !response.status().is_success() {
let status_code = response.status().as_u16();
let body = response.into_body().await?;
return Err(ModelListingError::api_error_with_context(
"OpenAI",
path,
status_code,
&body,
));
}
let body = response.into_body().await?;
let api_resp: ListModelsResponse = serde_json::from_slice(&body).map_err(|error| {
ModelListingError::parse_error_with_context("OpenAI", path, &error, &body)
})?;
let models = api_resp.data.into_iter().map(Model::from).collect();
Ok(ModelList::new(models))
}
}