use async_trait::async_trait;
use serde::Deserialize;
use tracing::{debug, error};
use crate::error::BaochuanError;
use crate::provider::{ChunkStream, Provider};
use crate::providers::openai_compat::OpenAICompatClient;
use crate::types::{ChatRequest, ChatResponse, ModelInfo};
/// Native REST API root used by [`LmStudioProvider::new`] when no other
/// server has been configured.
const DEFAULT_BASE_URL: &str = "http://localhost:1234/api/v0";
/// Top-level envelope of the model-list response.
///
/// Only the `data` array is read; any other top-level keys in the payload
/// are ignored by serde's default (non-strict) field handling.
#[derive(Deserialize)]
struct LmsModelList {
    /// One entry per model known to the server.
    data: Vec<LmsModel>,
}
/// A single model entry from LM Studio's native model-list endpoint.
///
/// LM Studio emits snake_case keys (e.g. `max_context_length`), so the Rust
/// field names are used as the JSON keys directly. A camelCase spelling of the
/// context-length key is also accepted via `alias`, in case a server variant
/// emits it; every other field is a single word and is spelled the same either way.
#[derive(Deserialize)]
struct LmsModel {
    /// Model identifier, used as [`ModelInfo::id`].
    id: String,
    /// Publishing organisation, surfaced as [`ModelInfo::owned_by`].
    publisher: Option<String>,
    /// Model architecture; first half of the display name.
    arch: Option<String>,
    /// Quantization label; appended to the architecture in the display name.
    quantization: Option<String>,
    /// Maximum context window reported by the server. Without the snake_case
    /// key matching, this would silently deserialize as `None`.
    #[serde(alias = "maxContextLength")]
    max_context_length: Option<u32>,
}
/// [`Provider`] backed by an LM Studio server.
///
/// Chat and streaming requests are delegated to an [`OpenAICompatClient`];
/// model listing goes through LM Studio's native API.
pub struct LmStudioProvider {
    /// Shared HTTP client whose `base_url` is the native `/api/v0` root.
    inner: OpenAICompatClient,
}
impl LmStudioProvider {
    /// Creates a provider targeting [`DEFAULT_BASE_URL`], with the inner client
    /// built by `OpenAICompatClient::no_key` (i.e. no API key configured).
    pub fn new() -> Self {
        Self { inner: OpenAICompatClient::no_key(DEFAULT_BASE_URL) }
    }

    /// Points the provider at a different LM Studio server.
    ///
    /// `base_url` is normally the server root (e.g. `http://192.168.1.10:1234`),
    /// to which the native `/api/v0` prefix is appended. Trailing slashes are
    /// ignored. A URL that already ends in `/api/v0` (the same shape as
    /// [`DEFAULT_BASE_URL`]) is used as-is rather than gaining a second prefix.
    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
        let base: String = base_url.into();
        self.inner.base_url = Self::native_api_url(&base);
        self
    }

    /// Stores an API key on the inner client. Requests built by this provider
    /// are passed through the client's `auth` helper before being sent.
    pub fn with_api_key(mut self, key: impl Into<String>) -> Self {
        self.inner.api_key = Some(key.into());
        self
    }

    /// Normalizes a user-supplied server URL into the native API root.
    ///
    /// Strips trailing `/` and appends `/api/v0` unless that suffix is already
    /// present, so both the server root and the full API root are accepted.
    fn native_api_url(base: &str) -> String {
        let root = base.trim_end_matches('/');
        if root.ends_with("/api/v0") {
            root.to_owned()
        } else {
            format!("{root}/api/v0")
        }
    }
}
impl Default for LmStudioProvider {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl Provider for LmStudioProvider {
fn name(&self) -> &str {
"lmstudio"
}
async fn models(&self) -> Result<Vec<ModelInfo>, BaochuanError> {
debug!("listing models from LM Studio native API");
let response = self
.inner
.auth(self.inner.client.get(self.inner.models_url()))
.send()
.await?;
let status = response.status();
if !status.is_success() {
let body = response.text().await.unwrap_or_default();
error!(status = %status, body = %body, "LM Studio models error");
return Err(BaochuanError::Api { status: status.as_u16(), message: body });
}
let list: LmsModelList = response.json().await?;
Ok(list.data.into_iter().map(|m| {
let display = match (&m.arch, &m.quantization) {
(Some(arch), Some(quant)) => Some(format!("{arch} · {quant}")),
(Some(arch), None) => Some(arch.clone()),
_ => None,
};
ModelInfo {
id: m.id,
owned_by: m.publisher,
context_length: m.max_context_length,
display_name: display,
}
}).collect())
}
async fn chat(&self, request: &ChatRequest) -> Result<ChatResponse, BaochuanError> {
self.inner.chat(request, self.name()).await
}
async fn stream_chat(&self, request: &ChatRequest) -> Result<ChunkStream, BaochuanError> {
self.inner.stream_chat(request, self.name()).await
}
}