use std::sync::Arc;

use futures::StreamExt;
use serdes_ai_core::{
    messages::ModelResponseStreamEvent, ModelRequest, ModelResponse, ModelSettings,
};
use serdes_ai_models::{BoxedModel, Model, ModelError, ModelRequestParameters, StreamedResponse};
use thiserror::Error;
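
/// Errors returned by the direct request helpers in this module.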
#[derive(Debug, Error)]
pub enum DirectError {
    #[error("Invalid model name: {0}")]
    InvalidModelName(String),
    #[error("Model error: {0}")]
    ModelError(#[from] ModelError),
    #[error("Runtime error: {0}")]
    RuntimeError(String),
    #[error("Provider not available: {0}. Enable the corresponding feature.")]
    ProviderNotAvailable(String),
}
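
/// A model to send requests to: either a provider-prefixed name such as
/// `"openai:gpt-4o"` (resolved lazily by name) or an already-constructed
/// model instance.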
#[derive(Clone)]
pub enum ModelSpec {
    Name(String),
    Instance(BoxedModel),
}

impl From<&str> for ModelSpec {
    fn from(s: &str) -> Self {
        ModelSpec::Name(s.to_string())
    }
}

impl From<String> for ModelSpec {
    fn from(s: String) -> Self {
        ModelSpec::Name(s)
    }
}

impl From<BoxedModel> for ModelSpec {
    fn from(model: BoxedModel) -> Self {
        ModelSpec::Instance(model)
    }
}

impl ModelSpec {
    pub fn from_model<M: Model + 'static>(model: M) -> Self {
        ModelSpec::Instance(Arc::new(model))
    }

    fn resolve(self) -> Result<BoxedModel, DirectError> {
        match self {
            ModelSpec::Name(name) => parse_model_name(&name),
            ModelSpec::Instance(model) => Ok(model),
        }
    }
}
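
/// Make a single, non-streaming request to a model and return the complete
/// response.
///
/// `model` accepts anything convertible into a [`ModelSpec`]: a
/// provider-prefixed name such as `"openai:gpt-4o"`, or a constructed
/// [`BoxedModel`]. `None` settings and parameters fall back to their
/// defaults.
///
/// A minimal sketch, marked `ignore` because message construction and
/// provider credentials (environment variables) are assumptions here:
///
/// ```ignore
/// // Hypothetical message value; build it with whatever constructor
/// // serdes_ai_core::ModelRequest actually provides.
/// let messages: Vec<ModelRequest> = vec![/* user prompt */];
/// let response = model_request("openai:gpt-4o", &messages, None, None).await?;
/// ```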
pub async fn model_request(
    model: impl Into<ModelSpec>,
    messages: &[ModelRequest],
    model_settings: Option<ModelSettings>,
    model_request_parameters: Option<ModelRequestParameters>,
) -> Result<ModelResponse, DirectError> {
    let model = model.into().resolve()?;
    let settings = model_settings.unwrap_or_default();
    let params = model_request_parameters.unwrap_or_default();
    let response = model.request(messages, &settings, &params).await?;
    Ok(response)
}
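
/// Blocking variant of [`model_request`] for synchronous callers. Builds a
/// throwaway current-thread Tokio runtime per call; calling it from inside
/// an async context returns [`DirectError::RuntimeError`] instead of
/// panicking or deadlocking.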
pub fn model_request_sync(
    model: impl Into<ModelSpec>,
    messages: &[ModelRequest],
    model_settings: Option<ModelSettings>,
    model_request_parameters: Option<ModelRequestParameters>,
) -> Result<ModelResponse, DirectError> {
    if tokio::runtime::Handle::try_current().is_ok() {
        return Err(DirectError::RuntimeError(
            "model_request_sync cannot be called from async context. Use model_request instead."
                .to_string(),
        ));
    }
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .map_err(|e| DirectError::RuntimeError(format!("Failed to create runtime: {e}")))?;
    let model_spec = model.into();
    let messages_owned: Vec<ModelRequest> = messages.to_vec();
    rt.block_on(async move {
        model_request(model_spec, &messages_owned, model_settings, model_request_parameters).await
    })
}
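
/// Make a streaming request, returning a [`StreamedResponse`] whose items
/// are `Result<ModelResponseStreamEvent, ModelError>`.
///
/// Consumption sketch with `futures::StreamExt` (model name and provider
/// credentials are assumed):
///
/// ```ignore
/// let mut stream = model_request_stream("openai:gpt-4o", &messages, None, None).await?;
/// while let Some(event) = stream.next().await {
///     println!("{:?}", event?);
/// }
/// ```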
pub async fn model_request_stream(
    model: impl Into<ModelSpec>,
    messages: &[ModelRequest],
    model_settings: Option<ModelSettings>,
    model_request_parameters: Option<ModelRequestParameters>,
) -> Result<StreamedResponse, DirectError> {
    let model = model.into().resolve()?;
    let settings = model_settings.unwrap_or_default();
    let params = model_request_parameters.unwrap_or_default();
    let stream = model.request_stream(messages, &settings, &params).await?;
    Ok(stream)
}
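
/// Blocking iterator over a streamed response. Owns a dedicated
/// current-thread Tokio runtime and uses it to drive the inner async stream
/// one event at a time.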
pub struct StreamedResponseSync {
    runtime: tokio::runtime::Runtime,
    stream: Option<StreamedResponse>,
}

impl StreamedResponseSync {
    fn new(stream: StreamedResponse) -> Result<Self, DirectError> {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|e| DirectError::RuntimeError(format!("Failed to create runtime: {e}")))?;
        Ok(Self {
            runtime,
            stream: Some(stream),
        })
    }
}

impl Iterator for StreamedResponseSync {
    type Item = Result<ModelResponseStreamEvent, ModelError>;

    fn next(&mut self) -> Option<Self::Item> {
        let stream = self.stream.as_mut()?;
        // Block on the owned runtime until the async stream yields its next
        // event (or ends).
        self.runtime.block_on(stream.next())
    }
}
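
/// Blocking variant of [`model_request_stream`]. A temporary runtime opens
/// the stream; iteration is then driven by the runtime owned by the
/// returned [`StreamedResponseSync`]. Like [`model_request_sync`], it
/// refuses to run inside an existing async context.
///
/// Sketch of blocking iteration (model name and credentials are assumed):
///
/// ```ignore
/// for event in model_request_stream_sync("anthropic:claude-3", &messages, None, None)? {
///     println!("{:?}", event?);
/// }
/// ```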
pub fn model_request_stream_sync(
    model: impl Into<ModelSpec>,
    messages: &[ModelRequest],
    model_settings: Option<ModelSettings>,
    model_request_parameters: Option<ModelRequestParameters>,
) -> Result<StreamedResponseSync, DirectError> {
    if tokio::runtime::Handle::try_current().is_ok() {
        return Err(DirectError::RuntimeError(
            "model_request_stream_sync cannot be called from async context. Use model_request_stream instead."
                .to_string(),
        ));
    }
    let setup_rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .map_err(|e| DirectError::RuntimeError(format!("Failed to create runtime: {e}")))?;
    let model_spec = model.into();
    let messages_owned: Vec<ModelRequest> = messages.to_vec();
    let stream = setup_rt.block_on(async move {
        model_request_stream(model_spec, &messages_owned, model_settings, model_request_parameters)
            .await
    })?;
    // The setup runtime is only needed to open the stream; from here on,
    // the runtime owned by StreamedResponseSync drives iteration.
    drop(setup_rt);
    StreamedResponseSync::new(stream)
}
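
/// Resolve a model name to a concrete model. With the `openai` feature
/// enabled, resolution is delegated to `serdes_ai_models::infer_model`;
/// otherwise the name must carry an explicit `provider:model` prefix that
/// maps onto an enabled provider feature.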
fn parse_model_name(name: &str) -> Result<BoxedModel, DirectError> {
    #[cfg(feature = "openai")]
    {
        serdes_ai_models::infer_model(name).map_err(DirectError::ModelError)
    }
    #[cfg(not(feature = "openai"))]
    {
        let Some((provider, model_name)) = name.split_once(':') else {
            return Err(DirectError::InvalidModelName(format!(
                "Model name '{}' requires a provider prefix (e.g., 'anthropic:{}') \
                 when the 'openai' feature is not enabled.",
                name, name
            )));
        };
        match provider {
            #[cfg(feature = "anthropic")]
            "anthropic" | "claude" => {
                let model = serdes_ai_models::AnthropicModel::from_env(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            #[cfg(feature = "groq")]
            "groq" => {
                let model = serdes_ai_models::GroqModel::from_env(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            #[cfg(feature = "mistral")]
            "mistral" => {
                let model = serdes_ai_models::MistralModel::from_env(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            #[cfg(feature = "ollama")]
            "ollama" => {
                let model = serdes_ai_models::OllamaModel::from_env(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            #[cfg(feature = "bedrock")]
            "bedrock" | "aws" => {
                let model = serdes_ai_models::BedrockModel::new(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            _ => Err(DirectError::ProviderNotAvailable(provider.to_string())),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_spec_from_str() {
        let spec: ModelSpec = "openai:gpt-4o".into();
        assert!(matches!(spec, ModelSpec::Name(ref s) if s == "openai:gpt-4o"));
    }

    #[test]
    fn test_model_spec_from_string() {
        let spec: ModelSpec = String::from("anthropic:claude-3").into();
        assert!(matches!(spec, ModelSpec::Name(ref s) if s == "anthropic:claude-3"));
    }

    #[test]
    fn test_direct_error_display() {
        let err = DirectError::InvalidModelName("bad-model".to_string());
        assert!(err.to_string().contains("bad-model"));
        let err = DirectError::ProviderNotAvailable("unknown".to_string());
        assert!(err.to_string().contains("unknown"));
        let err = DirectError::RuntimeError("something went wrong".to_string());
        assert!(err.to_string().contains("something went wrong"));
    }

    #[test]
    fn test_sync_runtime_detection() {
        // Calling the blocking API from inside an async context must fail
        // fast with a RuntimeError, before any model is resolved, rather
        // than attempting to nest runtimes.
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .unwrap();
        rt.block_on(async {
            let result = model_request_sync("openai:gpt-4o", &[], None, None);
            assert!(matches!(result, Err(DirectError::RuntimeError(_))));
            let result = model_request_stream_sync("openai:gpt-4o", &[], None, None);
            assert!(matches!(result, Err(DirectError::RuntimeError(_))));
        });
    }
}