use std::fmt;
use crate::error::LlmError;
use crate::openai::OpenAiProvider;
use crate::provider::{
ChatResponse, ChatStream, GenerationOverrides, LlmProvider, Message, StatusTx, ToolDefinition,
};
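/// An [`LlmProvider`] that speaks the OpenAI wire format against an
/// arbitrary base URL (Groq, local inference servers, and other
/// OpenAI-compatible endpoints). It reports a caller-supplied provider
/// name and delegates every request to an inner [`OpenAiProvider`].
///
/// A minimal construction sketch; the endpoint and model mirror the
/// test fixture below and are placeholders, not recommendations:
///
/// ```ignore
/// let provider = CompatibleProvider::new(
///     "groq".into(),                           // name reported by `name()`
///     "api-key".into(),                        // API key for the endpoint
///     "https://api.groq.com/openai/v1".into(), // OpenAI-compatible base URL
///     "llama-3.3-70b".into(),                  // model identifier
///     4096,                                    // max_tokens per response
///     None,                                    // no embedding model configured
/// );
/// ```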
#[derive(Clone)]
pub struct CompatibleProvider {
inner: OpenAiProvider,
provider_name: String,
}
impl CompatibleProvider {
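    /// Builds a provider for an OpenAI-compatible endpoint.
    ///
    /// When `embedding_model` is `None`, embeddings are unsupported:
    /// `supports_embeddings()` returns `false` and `embed` calls error.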
#[must_use]
pub fn new(
provider_name: String,
api_key: String,
base_url: String,
model: String,
max_tokens: u32,
embedding_model: Option<String>,
) -> Self {
let inner =
OpenAiProvider::new(api_key, base_url, model, max_tokens, embedding_model, None);
Self {
inner,
provider_name,
}
}
}
impl fmt::Debug for CompatibleProvider {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("CompatibleProvider")
.field("provider_name", &self.provider_name)
.field("inner", &self.inner)
.finish_non_exhaustive()
}
}
impl CompatibleProvider {
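    /// Fetches the remote model list by delegating to the inner
    /// [`OpenAiProvider`].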
pub async fn list_models_remote(
&self,
) -> Result<Vec<crate::model_cache::RemoteModelInfo>, LlmError> {
self.inner.list_models_remote().await
}
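    /// Installs a [`StatusTx`] channel on the inner provider.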
pub fn set_status_tx(&mut self, tx: StatusTx) {
self.inner.status_tx = Some(tx);
}
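    /// Applies [`GenerationOverrides`] to the inner provider,
    /// builder-style, e.g.
    /// `CompatibleProvider::new(/* ... */).with_generation_overrides(overrides)`.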
#[must_use]
pub fn with_generation_overrides(mut self, overrides: GenerationOverrides) -> Self {
self.inner = self.inner.with_generation_overrides(overrides);
self
}
}
impl LlmProvider for CompatibleProvider {
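    // Context windows vary across compatible endpoints and are not
    // tracked here, so no limit is reported.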
fn context_window(&self) -> Option<usize> {
None
}
#[cfg_attr(
feature = "profiling",
tracing::instrument(
name = "llm.chat",
skip_all,
fields(provider = self.name(), model = self.model_identifier())
)
)]
async fn chat(&self, messages: &[Message]) -> Result<String, LlmError> {
self.inner.chat(messages).await
}
#[cfg_attr(
feature = "profiling",
tracing::instrument(
name = "llm.chat_stream",
skip_all,
fields(provider = self.name(), model = self.model_identifier())
)
)]
async fn chat_stream(&self, messages: &[Message]) -> Result<ChatStream, LlmError> {
self.inner.chat_stream(messages).await
}
fn supports_streaming(&self) -> bool {
self.inner.supports_streaming()
}
#[cfg_attr(
feature = "profiling",
tracing::instrument(
name = "llm.embed",
skip_all,
fields(provider = self.name(), model = self.model_identifier())
)
)]
async fn embed(&self, text: &str) -> Result<Vec<f32>, LlmError> {
self.inner.embed(text).await
}
async fn embed_batch(&self, texts: &[&str]) -> Result<Vec<Vec<f32>>, LlmError> {
self.inner.embed_batch(texts).await
}
fn supports_embeddings(&self) -> bool {
self.inner.supports_embeddings()
}
fn name(&self) -> &str {
&self.provider_name
}
fn model_identifier(&self) -> &str {
self.inner.model_identifier()
}
fn list_models(&self) -> Vec<String> {
self.inner.list_models()
}
fn supports_structured_output(&self) -> bool {
self.inner.supports_structured_output()
}
async fn chat_typed<T>(&self, messages: &[Message]) -> Result<T, LlmError>
where
T: serde::de::DeserializeOwned + schemars::JsonSchema + 'static,
Self: Sized,
{
self.inner.chat_typed(messages).await
}
async fn chat_with_tools(
&self,
messages: &[Message],
tools: &[ToolDefinition],
) -> Result<ChatResponse, LlmError> {
self.inner.chat_with_tools(messages, tools).await
}
fn last_cache_usage(&self) -> Option<(u64, u64)> {
self.inner.last_cache_usage()
}
fn last_usage(&self) -> Option<(u64, u64)> {
self.inner.last_usage()
}
fn debug_request_json(
&self,
messages: &[Message],
tools: &[ToolDefinition],
stream: bool,
) -> serde_json::Value {
self.inner.debug_request_json(messages, tools, stream)
}
}
#[cfg(test)]
mod tests {
use super::*;
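    // Shared fixture: a provider pointed at a real-looking endpoint.
    // The synchronous tests below never touch the network.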
fn test_provider() -> CompatibleProvider {
CompatibleProvider::new(
"groq".into(),
"key".into(),
"https://api.groq.com/openai/v1".into(),
"llama-3.3-70b".into(),
4096,
None,
)
}
#[test]
fn name_returns_custom_provider_name() {
let p = test_provider();
assert_eq!(p.name(), "groq");
}
#[test]
fn context_window_returns_none() {
assert!(test_provider().context_window().is_none());
}
#[test]
fn supports_streaming_delegates() {
assert!(test_provider().supports_streaming());
}
#[test]
fn supports_embeddings_without_model() {
assert!(!test_provider().supports_embeddings());
}
#[test]
fn supports_embeddings_with_model() {
let p = CompatibleProvider::new(
"test".into(),
"key".into(),
"http://localhost".into(),
"m".into(),
100,
Some("embed-model".into()),
);
assert!(p.supports_embeddings());
}
#[test]
fn clone_preserves_name() {
let p = test_provider();
let c = p.clone();
assert_eq!(c.name(), "groq");
}
#[test]
fn debug_contains_provider_name() {
let debug = format!("{:?}", test_provider());
assert!(debug.contains("groq"));
assert!(debug.contains("CompatibleProvider"));
}
#[tokio::test]
async fn chat_unreachable_errors() {
let p = CompatibleProvider::new(
"test".into(),
"key".into(),
"http://127.0.0.1:1".into(),
"m".into(),
100,
None,
);
let msgs = vec![Message::from_legacy(crate::provider::Role::User, "hello")];
assert!(p.chat(&msgs).await.is_err());
}
#[tokio::test]
async fn embed_without_model_errors() {
let p = test_provider();
let result = p.embed("test").await;
assert!(result.is_err());
}
#[test]
fn last_usage_initially_none() {
assert!(test_provider().last_usage().is_none());
}
}