use anyhow::{Context, Result};
use async_trait::async_trait;
use reqwest::Client;
use serde::Deserialize;
use crate::model_listing::{
AvailableModel, ModelLister, OpenAIListResponse, infer_openai_capabilities,
};
/// Deserialization target for the JSON body of a models-listing response.
///
/// Only the top-level `data` array is read; with no `deny_unknown_fields`
/// attribute, serde silently ignores any other top-level keys.
///
/// NOTE(review): `OpenAIModelLister::list_models` in this file deserializes into
/// `crate::model_listing::OpenAIListResponse`, not this type — confirm whether
/// this struct is still used elsewhere or is a leftover duplicate.
#[derive(Clone, Debug, Deserialize)]
pub struct OpenAIListModelsResponse {
    /// One record per model reported by the endpoint.
    pub data: Vec<OpenAIModelEntry>,
}
/// A single model record from the `data` array of [`OpenAIListModelsResponse`].
#[derive(Clone, Debug, Deserialize)]
pub struct OpenAIModelEntry {
    /// Model identifier. Required: deserialization fails if the key is missing.
    pub id: String,
    /// Value of the `owned_by` key; `None` when the key is absent or `null`.
    pub owned_by: Option<String>,
    /// Value of the `created` key; `None` when the key is absent or `null`.
    /// NOTE(review): presumably a Unix timestamp in seconds — confirm against the API docs.
    pub created: Option<i64>,
}
/// Endpoint used when `OpenAIModelLister::new` is given no base URL.
const OPENAI_MODELS_LIST_URL: &str = "https://api.openai.com/v1/models";

/// [`ModelLister`] that queries an OpenAI models endpoint over HTTP.
pub struct OpenAIModelLister {
    /// HTTP client reused for every `list_models` request.
    http_client: Client,
    /// Either the models URL itself or a URL ending in `/chat/completions`,
    /// which `list_models` rewrites to the sibling `/models` endpoint.
    base_url: String,
    /// Secret sent as a bearer token on each request.
    api_key: String,
}
impl OpenAIModelLister {
    /// Builds a lister that authenticates with `api_key`.
    ///
    /// When `base_url` is `None`, requests are sent to [`OPENAI_MODELS_LIST_URL`].
    /// Each lister owns its own freshly constructed [`Client`].
    ///
    /// # Panics
    ///
    /// Propagates any panic from `reqwest::Client::new` (see the reqwest
    /// documentation for the conditions under which it can panic).
    pub fn new(api_key: String, base_url: Option<String>) -> Self {
        let base_url = match base_url {
            Some(url) => url,
            None => String::from(OPENAI_MODELS_LIST_URL),
        };
        let http_client = Client::new();
        Self {
            http_client,
            base_url,
            api_key,
        }
    }
}
/// Derives the models-listing URL from a configured base URL.
///
/// A URL ending in `/chat/completions` (optionally followed by trailing `/`s)
/// has only that suffix replaced by `/models`. Any other URL is returned
/// unchanged, so a URL that already points at the models endpoint (such as
/// the default [`OPENAI_MODELS_LIST_URL`]) is used as-is.
fn resolve_models_url(base_url: &str) -> String {
    // Compare against a slash-trimmed view so ".../chat/completions/" is recognized too.
    let trimmed = base_url.trim_end_matches('/');
    match trimmed.strip_suffix("/chat/completions") {
        // Only the trailing path segment is rewritten; earlier occurrences of the
        // same text elsewhere in the URL are left untouched.
        Some(prefix) => format!("{}/models", prefix),
        None => base_url.to_owned(),
    }
}

#[async_trait]
impl ModelLister for OpenAIModelLister {
    /// Fetches the models visible to this API key and maps them to [`AvailableModel`]s.
    ///
    /// Capabilities are inferred from each model id via `infer_openai_capabilities`.
    /// `display_name`, `context_window` and `max_output_tokens` are left as `None`
    /// because this code does not read them from the response.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - the request cannot be sent;
    /// - the server answers with a non-success status (the message includes the
    ///   status and, best-effort, the response body);
    /// - the body cannot be parsed as an `OpenAIListResponse`.
    async fn list_models(&self) -> Result<Vec<AvailableModel>> {
        let models_url = resolve_models_url(&self.base_url);

        let resp = self
            .http_client
            .get(&models_url)
            // Sends "Authorization: Bearer <key>" and marks the header as sensitive.
            .bearer_auth(&self.api_key)
            .send()
            .await
            .context("Failed to list OpenAI models")?;

        let status = resp.status();
        if !status.is_success() {
            // The body is diagnostic only; failing to read it must not hide the status.
            let body = resp.text().await.unwrap_or_default();
            anyhow::bail!("OpenAI models API returned {}: {}", status, body);
        }

        let list: OpenAIListResponse = resp
            .json()
            .await
            .context("Failed to parse OpenAI models response")?;

        let models = list
            .data
            .into_iter()
            .map(|entry| {
                // Infer capabilities while `entry.id` is still borrowed, then move the id.
                let capabilities = infer_openai_capabilities(&entry.id);
                AvailableModel {
                    id: entry.id,
                    display_name: None,
                    provider: crate::ProviderType::OpenAI,
                    capabilities,
                    owned_by: entry.owned_by,
                    context_window: None,
                    max_output_tokens: None,
                    created_at: entry.created,
                }
            })
            .collect();
        Ok(models)
    }
}