serdes_ai/direct.rs
//! Direct model request functions.
//!
//! These functions make imperative requests to models with minimal abstraction.
//! The only abstraction is input/output schema translation for unified API access.
//!
//! Use these when you want simple, direct access to models without the full agent
//! infrastructure. Great for one-off queries, scripts, and simple integrations.
//!
//! # Examples
//!
//! ## Non-streaming request
//!
//! ```rust,ignore
//! use serdes_ai::direct::model_request;
//! use serdes_ai_core::ModelRequest;
//!
//! let response = model_request(
//!     "openai:gpt-4o",
//!     &[ModelRequest::user("What is the capital of France?")],
//!     None,
//!     None,
//! ).await?;
//!
//! println!("{}", response.text());
//! ```
//!
//! ## Streaming request
//!
//! ```rust,ignore
//! use serdes_ai::direct::model_request_stream;
//! use serdes_ai_core::ModelRequest;
//! use futures::StreamExt;
//!
//! let mut stream = model_request_stream(
//!     "anthropic:claude-3-5-sonnet",
//!     &[ModelRequest::user("Write a poem")],
//!     None,
//!     None,
//! ).await?;
//!
//! while let Some(event) = stream.next().await {
//!     // Handle streaming events
//! }
//! ```
//!
//! ## Using a pre-built model instance
//!
//! ```rust,ignore
//! use serdes_ai::direct::{model_request, ModelSpec};
//! use serdes_ai_core::ModelRequest;
//! use serdes_ai_models::openai::OpenAIChatModel;
//!
//! let model = OpenAIChatModel::from_env("gpt-4o")?;
//! let response = model_request(
//!     ModelSpec::from_model(model),
//!     &[ModelRequest::user("Hello!")],
//!     None,
//!     None,
//! ).await?;
//! ```

use std::sync::Arc;

use futures::StreamExt;
use serdes_ai_core::{
    messages::ModelResponseStreamEvent, ModelRequest, ModelResponse, ModelSettings,
};
use serdes_ai_models::{BoxedModel, Model, ModelError, ModelRequestParameters, StreamedResponse};
use thiserror::Error;

// ============================================================================
// Error Type
// ============================================================================

/// Error type for direct requests.
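///
/// # Example
///
/// A minimal, illustrative sketch of distinguishing configuration problems from
/// request failures (assumes an async context and a prepared `messages` slice):
///
/// ```rust,ignore
/// use serdes_ai::direct::{model_request, DirectError};
///
/// match model_request("openai:gpt-4o", &messages, None, None).await {
///     Ok(response) => println!("{}", response.text()),
///     Err(DirectError::ProviderNotAvailable(provider)) => {
///         eprintln!("enable the `{provider}` feature to use this model");
///     }
///     Err(err) => eprintln!("request failed: {err}"),
/// }
/// ```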
#[derive(Debug, Error)]
pub enum DirectError {
    /// Invalid model name format.
    #[error("Invalid model name: {0}")]
    InvalidModelName(String),

    /// Model-level error (API, network, etc.).
    #[error("Model error: {0}")]
    ModelError(#[from] ModelError),

    /// Runtime error (e.g., a sync function called from an async context).
    #[error("Runtime error: {0}")]
    RuntimeError(String),

    /// Provider not available (feature not enabled).
    #[error("Provider not available: {0}. Enable the corresponding feature.")]
    ProviderNotAvailable(String),
}

// ============================================================================
// Model Specification
// ============================================================================

/// Model specification - either a string like "openai:gpt-4o" or a `Model` instance.
///
/// This allows flexible model specification in the direct API functions.
///
/// # Examples
///
/// ```rust,ignore
/// // From a string
/// let spec: ModelSpec = "openai:gpt-4o".into();
///
/// // From a model instance
/// let model = OpenAIChatModel::from_env("gpt-4o")?;
/// let spec = ModelSpec::from_model(model);
/// ```
#[derive(Clone)]
pub enum ModelSpec {
    /// Model specified by name (e.g., "openai:gpt-4o").
    Name(String),
    /// Pre-built model instance.
    Instance(BoxedModel),
}

impl From<&str> for ModelSpec {
    fn from(s: &str) -> Self {
        ModelSpec::Name(s.to_string())
    }
}

impl From<String> for ModelSpec {
    fn from(s: String) -> Self {
        ModelSpec::Name(s)
    }
}

impl From<BoxedModel> for ModelSpec {
    fn from(model: BoxedModel) -> Self {
        ModelSpec::Instance(model)
    }
}

impl ModelSpec {
    /// Create a `ModelSpec` from any concrete `Model` type.
    ///
    /// This is a convenience method for wrapping concrete model types.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// use serdes_ai::direct::ModelSpec;
    /// use serdes_ai_models::openai::OpenAIChatModel;
    ///
    /// let model = OpenAIChatModel::from_env("gpt-4o")?;
    /// let spec = ModelSpec::from_model(model);
    /// ```
    pub fn from_model<M: Model + 'static>(model: M) -> Self {
        ModelSpec::Instance(Arc::new(model))
    }
}

impl ModelSpec {
    /// Resolve the spec into a concrete model instance.
    fn resolve(self) -> Result<BoxedModel, DirectError> {
        match self {
            ModelSpec::Name(name) => parse_model_name(&name),
            ModelSpec::Instance(model) => Ok(model),
        }
    }
}

// ============================================================================
// Non-Streaming Requests
// ============================================================================

/// Make a non-streamed request to a model.
///
/// This is the simplest way to get a response from a model. It completes once
/// the full response is available.
///
/// # Arguments
///
/// * `model` - Model specification (a string like "openai:gpt-4o" or a `Model` instance)
/// * `messages` - Slice of request messages
/// * `model_settings` - Optional model settings (temperature, max_tokens, etc.)
/// * `model_request_parameters` - Optional request parameters (tools, output schema, etc.)
///
/// # Example
///
/// ```rust,ignore
/// use serdes_ai::direct::model_request;
/// use serdes_ai_core::ModelRequest;
///
/// let response = model_request(
///     "openai:gpt-4o",
///     &[ModelRequest::user("What is the capital of France?")],
///     None,
///     None,
/// ).await?;
///
/// println!("{}", response.text());
/// ```
pub async fn model_request(
    model: impl Into<ModelSpec>,
    messages: &[ModelRequest],
    model_settings: Option<ModelSettings>,
    model_request_parameters: Option<ModelRequestParameters>,
) -> Result<ModelResponse, DirectError> {
    let model = model.into().resolve()?;
    let settings = model_settings.unwrap_or_default();
    let params = model_request_parameters.unwrap_or_default();

    let response = model.request(messages, &settings, &params).await?;
    Ok(response)
}

/// Make a synchronous (blocking) non-streamed request.
///
/// This wraps `model_request` with a Tokio runtime. It creates a new runtime
/// for each call, so it is not the most efficient option for high-throughput
/// scenarios.
///
/// # Warning
///
/// Cannot be used inside async code (returns a [`DirectError::RuntimeError`] if
/// called from an async context). Use `model_request` instead in async contexts.
///
/// # Example
///
/// ```rust,ignore
/// use serdes_ai::direct::model_request_sync;
/// use serdes_ai_core::ModelRequest;
///
/// fn main() {
///     let response = model_request_sync(
///         "openai:gpt-4o",
///         &[ModelRequest::user("Hello!")],
///         None,
///         None,
///     ).unwrap();
///
///     println!("{}", response.text());
/// }
/// ```
pub fn model_request_sync(
    model: impl Into<ModelSpec>,
    messages: &[ModelRequest],
    model_settings: Option<ModelSettings>,
    model_request_parameters: Option<ModelRequestParameters>,
) -> Result<ModelResponse, DirectError> {
    // Check if we're already in an async context.
    if tokio::runtime::Handle::try_current().is_ok() {
        return Err(DirectError::RuntimeError(
            "model_request_sync cannot be called from async context. Use model_request instead."
                .to_string(),
        ));
    }

    // Create a new runtime for the blocking call.
    let rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .map_err(|e| DirectError::RuntimeError(format!("Failed to create runtime: {e}")))?;

    // Clone what we need since we can't move references into the async block.
    let model_spec = model.into();
    let messages_owned: Vec<ModelRequest> = messages.to_vec();
    let settings = model_settings;
    let params = model_request_parameters;

    rt.block_on(async move { model_request(model_spec, &messages_owned, settings, params).await })
}

// ============================================================================
// Streaming Requests
// ============================================================================

/// Make a streaming request to a model.
///
/// Returns a stream of response events that can be processed as they arrive.
/// This is useful for real-time output and long responses.
///
/// # Arguments
///
/// * `model` - Model specification (a string like "openai:gpt-4o" or a `Model` instance)
/// * `messages` - Slice of request messages
/// * `model_settings` - Optional model settings (temperature, max_tokens, etc.)
/// * `model_request_parameters` - Optional request parameters (tools, output schema, etc.)
///
/// # Example
///
/// ```rust,ignore
/// use serdes_ai::direct::model_request_stream;
/// use serdes_ai_core::messages::ModelResponseStreamEvent;
/// use serdes_ai_core::ModelRequest;
/// use futures::StreamExt;
///
/// let mut stream = model_request_stream(
///     "anthropic:claude-3-5-sonnet",
///     &[ModelRequest::user("Write a poem about Rust")],
///     None,
///     None,
/// ).await?;
///
/// while let Some(event) = stream.next().await {
///     match event? {
///         ModelResponseStreamEvent::PartDelta(delta) => {
///             if let Some(text) = delta.delta.content_delta() {
///                 print!("{}", text);
///             }
///         }
///         _ => {}
///     }
/// }
/// ```
pub async fn model_request_stream(
    model: impl Into<ModelSpec>,
    messages: &[ModelRequest],
    model_settings: Option<ModelSettings>,
    model_request_parameters: Option<ModelRequestParameters>,
) -> Result<StreamedResponse, DirectError> {
    let model = model.into().resolve()?;
    let settings = model_settings.unwrap_or_default();
    let params = model_request_parameters.unwrap_or_default();

    let stream = model.request_stream(messages, &settings, &params).await?;
    Ok(stream)
}

/// Synchronous streaming request wrapper.
///
/// This struct wraps a streaming response and provides a synchronous iterator
/// interface for consuming streaming events.
///
/// # Warning
///
/// Cannot be used inside async code (will panic if called from an async context).
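///
/// # Example
///
/// A minimal sketch of draining the iterator obtained from
/// [`model_request_stream_sync`]; each yielded item is a
/// `Result<ModelResponseStreamEvent, ModelError>` (assumes `messages` is a
/// prepared `&[ModelRequest]`):
///
/// ```rust,ignore
/// let stream = model_request_stream_sync("openai:gpt-4o", &messages, None, None)?;
/// for event in stream {
///     let event = event?;
///     // Inspect `event` (e.g., accumulate text deltas).
/// }
/// ```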
pub struct StreamedResponseSync {
    /// The underlying async runtime.
    runtime: tokio::runtime::Runtime,
    /// The underlying async stream.
    stream: Option<StreamedResponse>,
}

impl StreamedResponseSync {
    /// Create a new sync wrapper around an async stream.
    fn new(stream: StreamedResponse) -> Result<Self, DirectError> {
        let runtime = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .map_err(|e| DirectError::RuntimeError(format!("Failed to create runtime: {e}")))?;

        Ok(Self {
            runtime,
            stream: Some(stream),
        })
    }
}

impl Iterator for StreamedResponseSync {
    type Item = Result<ModelResponseStreamEvent, ModelError>;

    fn next(&mut self) -> Option<Self::Item> {
        let stream = self.stream.as_mut()?;
        self.runtime.block_on(stream.next())
    }
}

/// Synchronous streaming request.
///
/// This creates a streaming request and wraps it in a synchronous iterator.
///
/// # Warning
///
/// Cannot be used inside async code (returns a [`DirectError::RuntimeError`] if
/// called from an async context). Use `model_request_stream` instead in async
/// contexts.
///
/// # Example
///
/// ```rust,ignore
/// use serdes_ai::direct::model_request_stream_sync;
/// use serdes_ai_core::ModelRequest;
///
/// fn main() {
///     let stream = model_request_stream_sync(
///         "openai:gpt-4o",
///         &[ModelRequest::user("Tell me a story")],
///         None,
///         None,
///     ).unwrap();
///
///     for event in stream {
///         // Handle each event
///     }
/// }
/// ```
pub fn model_request_stream_sync(
    model: impl Into<ModelSpec>,
    messages: &[ModelRequest],
    model_settings: Option<ModelSettings>,
    model_request_parameters: Option<ModelRequestParameters>,
) -> Result<StreamedResponseSync, DirectError> {
    // Check if we're already in an async context.
    if tokio::runtime::Handle::try_current().is_ok() {
        return Err(DirectError::RuntimeError(
            "model_request_stream_sync cannot be called from async context. Use model_request_stream instead."
                .to_string(),
        ));
    }

    // Create a runtime to set up the stream.
    let setup_rt = tokio::runtime::Builder::new_current_thread()
        .enable_all()
        .build()
        .map_err(|e| DirectError::RuntimeError(format!("Failed to create runtime: {e}")))?;

    let model_spec = model.into();
    let messages_owned: Vec<ModelRequest> = messages.to_vec();
    let settings = model_settings;
    let params = model_request_parameters;

    let stream = setup_rt.block_on(async move {
        model_request_stream(model_spec, &messages_owned, settings, params).await
    })?;

    // Drop the setup runtime and create the iterator with its own runtime.
    drop(setup_rt);

    StreamedResponseSync::new(stream)
}

// ============================================================================
// Model Parsing
// ============================================================================

/// Parse a model name like "openai:gpt-4o" into a model instance.
///
/// Supported formats (see the examples below):
/// - `provider:model_name` (e.g., "openai:gpt-4o", "anthropic:claude-3-5-sonnet")
/// - `model_name` (defaults to OpenAI)
///
/// Available providers (when their features are enabled):
/// - `openai` / `gpt`: OpenAI models
/// - `anthropic` / `claude`: Anthropic Claude models
/// - `groq`: Groq fast inference
/// - `mistral`: Mistral AI models
/// - `ollama`: Local Ollama models
/// - `bedrock` / `aws`: AWS Bedrock models
/// - `openrouter` / `or`: OpenRouter multi-provider
/// - `huggingface` / `hf`: HuggingFace Inference API
/// - `cohere` / `co`: Cohere models
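///
/// # Examples
///
/// Illustrative only: when the `openai` feature is enabled, resolution is
/// delegated to `serdes_ai_models::infer_model`, so exact routing depends on
/// the enabled features.
///
/// ```rust,ignore
/// // Provider-prefixed name; `claude` is an alias for the Anthropic provider.
/// let model = parse_model_name("anthropic:claude-3-5-sonnet")?;
/// let same_provider = parse_model_name("claude:claude-3-5-sonnet")?;
///
/// // No prefix: routed to OpenAI by default (requires the `openai` feature).
/// let default = parse_model_name("gpt-4o")?;
/// ```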
fn parse_model_name(name: &str) -> Result<BoxedModel, DirectError> {
    // Use the infer_model function from serdes-ai-models
    #[cfg(feature = "openai")]
    {
        serdes_ai_models::infer_model(name).map_err(DirectError::ModelError)
    }

    #[cfg(not(feature = "openai"))]
    {
        // Without the openai feature, we need manual parsing
        let (provider, model_name) = if name.contains(':') {
            let parts: Vec<&str> = name.splitn(2, ':').collect();
            (parts[0], parts[1])
        } else {
            return Err(DirectError::InvalidModelName(format!(
                "Model name '{}' requires a provider prefix (e.g., 'anthropic:{}') \
                 when the 'openai' feature is not enabled.",
                name, name
            )));
        };

        match provider {
            #[cfg(feature = "anthropic")]
            "anthropic" | "claude" => {
                let model = serdes_ai_models::AnthropicModel::from_env(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            #[cfg(feature = "groq")]
            "groq" => {
                let model = serdes_ai_models::GroqModel::from_env(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            #[cfg(feature = "mistral")]
            "mistral" => {
                let model = serdes_ai_models::MistralModel::from_env(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            #[cfg(feature = "ollama")]
            "ollama" => {
                let model = serdes_ai_models::OllamaModel::from_env(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            #[cfg(feature = "bedrock")]
            "bedrock" | "aws" => {
                let model = serdes_ai_models::BedrockModel::new(model_name)
                    .map_err(DirectError::ModelError)?;
                Ok(Arc::new(model))
            }
            _ => Err(DirectError::ProviderNotAvailable(provider.to_string())),
        }
    }
}

// ============================================================================
// Tests
// ============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_spec_from_str() {
        let spec: ModelSpec = "openai:gpt-4o".into();
        assert!(matches!(spec, ModelSpec::Name(ref s) if s == "openai:gpt-4o"));
    }

    #[test]
    fn test_model_spec_from_string() {
        let spec: ModelSpec = String::from("anthropic:claude-3").into();
        assert!(matches!(spec, ModelSpec::Name(ref s) if s == "anthropic:claude-3"));
    }

    #[test]
    fn test_direct_error_display() {
        let err = DirectError::InvalidModelName("bad-model".to_string());
        assert!(err.to_string().contains("bad-model"));

        let err = DirectError::ProviderNotAvailable("unknown".to_string());
        assert!(err.to_string().contains("unknown"));

        let err = DirectError::RuntimeError("something went wrong".to_string());
        assert!(err.to_string().contains("something went wrong"));
    }

    #[test]
    fn test_sync_runtime_detection() {
        // In a normal sync context the runtime detection should pass, and any
        // failure would come from the actual request (e.g., missing API keys),
        // so no network-dependent assertion is made here.

        // Rejection from within an async context is covered by
        // `test_sync_rejects_async_context` below.
    }
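
    // Exercises the async-context guard in `model_request_sync`. This relies on
    // `Runtime::block_on` entering the Tokio runtime context, so
    // `Handle::try_current()` succeeds inside the block and the call is
    // rejected before any model is resolved or any network request is made.
    #[test]
    fn test_sync_rejects_async_context() {
        let rt = tokio::runtime::Builder::new_current_thread()
            .enable_all()
            .build()
            .expect("failed to build test runtime");

        let result = rt.block_on(async { model_request_sync("openai:gpt-4o", &[], None, None) });

        assert!(matches!(result, Err(DirectError::RuntimeError(_))));
    }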
}