ai_sdk_provider/language_model/trait_def.rs

use super::*;
use crate::{Error, Result, SharedHeaders, SharedProviderMetadata};
use async_trait::async_trait;
use std::collections::HashMap;
use std::future::{Future, IntoFuture};
use std::pin::Pin;
use std::sync::Arc;
use tokio_stream::Stream;

macro_rules! impl_call_options {
    () => {
        /// Sets the maximum number of tokens to generate in the model output.
        ///
        /// This parameter controls the maximum length of the generated response. The actual
        /// number of tokens generated may be less if the model reaches a natural stopping point
        /// or encounters a stop sequence.
        pub fn max_tokens(mut self, max_tokens: u32) -> Self {
            self.options.max_output_tokens = Some(max_tokens);
            self
        }

        /// Sets the sampling temperature for controlling output randomness.
        ///
        /// Valid range: 0.0 to 2.0
        /// - Lower values (e.g., 0.2): More focused, deterministic outputs
        /// - Higher values (e.g., 1.5): More creative, diverse outputs
        /// - Default: Provider-specific (typically 1.0)
        pub fn temperature(mut self, temperature: f32) -> Self {
            self.options.temperature = Some(temperature);
            self
        }

        /// Sets sequences that will cause generation to stop when encountered.
        ///
        /// When the model generates any of these sequences, generation will terminate
        /// immediately. Useful for structured output or preventing unwanted continuations.
        pub fn stop_sequences(mut self, sequences: Vec<String>) -> Self {
            self.options.stop_sequences = Some(sequences);
            self
        }

        /// Sets the nucleus sampling probability mass.
        ///
        /// Valid range: 0.0 to 1.0
        /// Sampling is restricted to the smallest set of highest-probability tokens whose
        /// cumulative probability reaches this threshold. Lower values produce more focused
        /// outputs by restricting the token pool to high-probability candidates.
        pub fn top_p(mut self, top_p: f32) -> Self {
            self.options.top_p = Some(top_p);
            self
        }

        /// Sets the top-k sampling parameter.
        ///
        /// Limits sampling to the k most probable next tokens. Reduces output diversity
        /// by eliminating low-probability tokens from consideration. Provider support varies.
        pub fn top_k(mut self, top_k: u32) -> Self {
            self.options.top_k = Some(top_k);
            self
        }

        /// Sets the presence penalty for reducing topic repetition.
        ///
        /// Valid range: -2.0 to 2.0
        /// Positive values penalize tokens that have already appeared in the generated text,
        /// encouraging the model to explore new topics. Negative values have the opposite effect.
        pub fn presence_penalty(mut self, penalty: f32) -> Self {
            self.options.presence_penalty = Some(penalty);
            self
        }

        /// Sets the frequency penalty for reducing token repetition.
        ///
        /// Valid range: -2.0 to 2.0
        /// Positive values penalize tokens based on their frequency in the generated text,
        /// with stronger penalties for frequently occurring tokens. This reduces repetitive phrasing.
        pub fn frequency_penalty(mut self, penalty: f32) -> Self {
            self.options.frequency_penalty = Some(penalty);
            self
        }

        /// Sets a random seed for deterministic generation.
        ///
        /// Using the same seed with identical parameters should produce consistent outputs
        /// across multiple requests. Useful for reproducible testing and debugging.
        /// Provider support for deterministic generation varies.
        pub fn seed(mut self, seed: i64) -> Self {
            self.options.seed = Some(seed);
            self
        }

        /// Sets the tools available for the model to call during generation.
        ///
        /// Tools enable function calling where the model can request execution of
        /// external functions with structured arguments. The application is responsible
        /// for executing tool calls and providing results back to the model.
        pub fn tools(mut self, tools: Vec<Tool>) -> Self {
            self.options.tools = Some(tools);
            self
        }

        /// Sets the tool selection strategy.
        ///
        /// Controls how the model decides whether and which tools to call:
        /// - `Auto`: Model autonomously decides when to use tools
        /// - `Required`: Model must call at least one tool
        /// - `None`: Tools are disabled for this request
        /// - `Tool { name }`: Model must call the specified tool
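        ///
        /// # Example
        ///
        /// An illustrative sketch; `my_tools` is a hypothetical `Vec<Tool>` built elsewhere.
        ///
        /// ```rust,ignore
        /// let response = model.generate("What is the weather in Paris?")
        ///     .tools(my_tools)
        ///     .tool_choice(ToolChoice::Auto)
        ///     .await?;
        /// ```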
        pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
            self.options.tool_choice = Some(choice);
            self
        }

        /// Forces a specific output format for the generated response.
        ///
        /// Constrains the model to produce output in the specified format:
        /// - `ResponseFormat::Text`: Plain text output (default)
        /// - `ResponseFormat::Json`: Valid JSON object or array
        ///
        /// JSON format is useful for structured data extraction and schema-compliant outputs.
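        ///
        /// # Example
        ///
        /// A minimal sketch, assuming `ResponseFormat::Json` is constructible as listed above.
        ///
        /// ```rust,ignore
        /// let response = model.generate("List three colors as a JSON array")
        ///     .response_format(ResponseFormat::Json)
        ///     .await?;
        /// ```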
        pub fn response_format(mut self, format: ResponseFormat) -> Self {
            self.options.response_format = Some(format);
            self
        }

        /// Sets custom HTTP headers for the provider request.
        ///
        /// Replaces any existing headers. Use this for provider-specific headers such as
        /// organization identifiers, API version selection, or custom authentication schemes.
        pub fn headers(mut self, headers: HashMap<String, String>) -> Self {
            self.options.headers = Some(headers);
            self
        }

        /// Adds a single HTTP header to the provider request.
        ///
        /// Merges with existing headers rather than replacing them. Useful for
        /// incrementally building header sets in a builder pattern.
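        ///
        /// # Example
        ///
        /// An illustrative sketch; the header names are hypothetical.
        ///
        /// ```rust,ignore
        /// let response = model.generate("Hello")
        ///     .header("X-Org-Id", "org-123")
        ///     .header("X-Request-Tag", "demo")
        ///     .await?;
        /// ```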
        pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
            let headers = self.options.headers.get_or_insert_with(HashMap::new);
            headers.insert(key.into(), value.into());
            self
        }
    };
}

/// Builder for configuring and executing non-streaming generation requests.
///
/// This builder provides a fluent interface for setting generation parameters before
/// executing the request. It supports both explicit execution via `.send()` and
/// implicit execution by awaiting the builder directly.
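///
/// # Example
///
/// A minimal sketch showing both execution styles, assuming `model` implements
/// [`LanguageModel`].
///
/// ```rust,ignore
/// // Implicit execution: awaiting the builder runs the request.
/// let response = model.generate("Hi").temperature(0.2).await?;
///
/// // Explicit execution via `.send()` is equivalent.
/// let response = model.generate("Hi").temperature(0.2).send().await?;
/// ```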
#[derive(Clone)]
pub struct GenerateBuilder<'a, M: LanguageModel + ?Sized> {
    model: &'a M,
    options: CallOptions,
    init_error: Option<Arc<Error>>,
}

impl<'a, M: LanguageModel + ?Sized> GenerateBuilder<'a, M> {
    /// Creates a new generation builder with the specified prompt.
    ///
    /// If the prompt conversion fails, the error is captured and will be returned
    /// when the builder is executed.
    pub fn new(model: &'a M, prompt_result: std::result::Result<Prompt, Error>) -> Self {
        match prompt_result {
            Ok(prompt) => Self {
                model,
                options: CallOptions {
                    prompt,
                    ..Default::default()
                },
                init_error: None,
            },
            Err(e) => Self {
                model,
                options: CallOptions::default(),
                init_error: Some(Arc::new(e)),
            },
        }
    }

    impl_call_options!();

    /// Executes the generation request and returns the complete response.
    ///
    /// This method consumes the builder and performs the actual API call to the
    /// language model provider. Use this for explicit execution, or simply await
    /// the builder directly for implicit execution.
    pub async fn send(self) -> Result<GenerateResponse> {
        if let Some(err) = self.init_error {
            // The error is stored in an `Arc` and cannot be moved out, so report it
            // by wrapping its string representation in a new error.
            return Err(format!("Initialization error: {}", err).into());
        }
        self.model.do_generate(self.options).await
    }
}

/// Enables awaiting the builder directly without calling `.send()`.
///
/// This implementation allows the syntax `model.generate(prompt).await` as a
/// shorthand for `model.generate(prompt).send().await`.
impl<'a, M: LanguageModel + ?Sized> IntoFuture for GenerateBuilder<'a, M> {
    type Output = Result<GenerateResponse>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

    fn into_future(self) -> Self::IntoFuture {
        if let Some(err) = self.init_error {
            return Box::pin(async move { Err(format!("Initialization error: {}", err).into()) });
        }
        let model = self.model;
        let options = self.options;
        Box::pin(async move { model.do_generate(options).await })
    }
}

/// Builder for configuring and executing streaming generation requests.
///
/// This builder provides a fluent interface for setting generation parameters before
/// initiating a streaming response. Streaming allows processing partial responses as
/// they arrive, enabling real-time display and reduced time-to-first-token.
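///
/// # Example
///
/// A minimal sketch of explicit execution, assuming `model` implements
/// [`LanguageModel`].
///
/// ```rust,ignore
/// let response = model.stream("Tell me a story")
///     .max_tokens(200)
///     .send()
///     .await?;
/// // `response.stream` yields `StreamPart` items as they arrive.
/// ```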
#[derive(Clone)]
pub struct StreamBuilder<'a, M: LanguageModel + ?Sized> {
    model: &'a M,
    options: CallOptions,
    init_error: Option<Arc<Error>>,
}

impl<'a, M: LanguageModel + ?Sized> StreamBuilder<'a, M> {
    /// Creates a new streaming builder with the specified prompt.
    ///
    /// If the prompt conversion fails, the error is captured and will be returned
    /// when the builder is executed.
    pub fn new(model: &'a M, prompt_result: std::result::Result<Prompt, Error>) -> Self {
        match prompt_result {
            Ok(prompt) => Self {
                model,
                options: CallOptions {
                    prompt,
                    ..Default::default()
                },
                init_error: None,
            },
            Err(e) => Self {
                model,
                options: CallOptions::default(),
                init_error: Some(Arc::new(e)),
            },
        }
    }

    impl_call_options!();

    /// Initiates the streaming request and returns a stream of response parts.
    ///
    /// This method consumes the builder and establishes the streaming connection.
    /// Use this for explicit execution, or simply await the builder directly.
    pub async fn send(self) -> Result<StreamResponse> {
        if let Some(err) = self.init_error {
            return Err(format!("Initialization error: {}", err).into());
        }
        self.model.do_stream(self.options).await
    }
}

/// Enables awaiting the builder directly without calling `.send()`.
///
/// This implementation allows the syntax `model.stream(prompt).await` as a
/// shorthand for `model.stream(prompt).send().await`.
impl<'a, M: LanguageModel + ?Sized> IntoFuture for StreamBuilder<'a, M> {
    type Output = Result<StreamResponse>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

    fn into_future(self) -> Self::IntoFuture {
        if let Some(err) = self.init_error {
            return Box::pin(async move { Err(format!("Initialization error: {}", err).into()) });
        }
        let model = self.model;
        let options = self.options;
        Box::pin(async move { model.do_stream(options).await })
    }
}

/// Core trait for language model providers.
///
/// This trait defines the standard interface that all language model implementations
/// must satisfy. It provides both high-level builder methods for ergonomic usage and
/// low-level `do_*` methods for direct provider integration.
///
/// # Implementation Requirements
///
/// Implementors must provide:
/// - `provider()`: Provider identifier (e.g., "openai", "anthropic")
/// - `model_id()`: Model identifier (e.g., "gpt-4", "claude-3-opus")
/// - `do_generate()`: Non-streaming generation implementation
/// - `do_stream()`: Streaming generation implementation
///
/// # Example Implementation
///
/// ```rust,ignore
/// use ai_sdk_provider::{
///     CallOptions, GenerateResponse, LanguageModel, Result, StreamResponse,
/// };
/// use async_trait::async_trait;
///
/// struct MyModel {
///     api_key: String,
/// }
///
/// #[async_trait]
/// impl LanguageModel for MyModel {
///     fn provider(&self) -> &str { "my-provider" }
///     fn model_id(&self) -> &str { "my-model-v1" }
///
///     async fn do_generate(&self, options: CallOptions) -> Result<GenerateResponse> {
///         // Implementation
///     }
///
///     async fn do_stream(&self, options: CallOptions) -> Result<StreamResponse> {
///         // Implementation
///     }
/// }
/// ```
#[async_trait]
pub trait LanguageModel: Send + Sync {
    /// Returns the specification version implemented by this model.
    ///
    /// The default implementation returns "v3". Override this only if implementing
    /// a different specification version.
    fn specification_version(&self) -> &str {
        "v3"
    }

    /// Returns the provider identifier for this model.
    ///
    /// This should be a stable, lowercase identifier for the AI provider
    /// (e.g., "openai", "anthropic", "cohere").
    fn provider(&self) -> &str;

    /// Returns the specific model identifier.
    ///
    /// This should be the exact model name as recognized by the provider
    /// (e.g., "gpt-4-turbo", "claude-3-opus-20240229").
    fn model_id(&self) -> &str;

    /// Returns URLs supported by this model for various operations.
    ///
    /// The default implementation returns an empty map. Providers can override this
    /// to expose endpoint URLs for debugging or advanced configuration.
    async fn supported_urls(&self) -> HashMap<String, Vec<String>> {
        HashMap::new()
    }

    /// Creates a builder for a non-streaming generation request.
    ///
    /// This high-level method provides a fluent interface for configuring and executing
    /// generation requests. The builder supports both explicit execution via `.send()`
    /// and implicit execution by awaiting the builder directly.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let response = model.generate("Explain photosynthesis")
    ///     .temperature(0.7)
    ///     .max_tokens(500)
    ///     .await?;
    /// println!("{:?}", response.content);
    /// ```
    fn generate<P>(&self, prompt: P) -> GenerateBuilder<'_, Self>
    where
        Self: Sized,
        P: TryInto<Prompt>,
        Error: From<P::Error>,
    {
        let result = prompt.try_into().map_err(Error::from);
        GenerateBuilder::new(self, result)
    }

    /// Creates a builder for a streaming generation request.
    ///
    /// Streaming enables processing partial responses as they arrive from the provider,
    /// allowing for real-time display and reduced latency to first token. The stream
    /// yields `StreamPart` items containing text deltas, tool calls, and metadata.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// use tokio_stream::StreamExt;
    ///
    /// let mut response = model.stream("Write a poem")
    ///     .max_tokens(500)
    ///     .await?;
    ///
    /// while let Some(part) = response.stream.next().await {
    ///     match part? {
    ///         StreamPart::TextDelta(delta) => print!("{}", delta),
    ///         _ => {}
    ///     }
    /// }
    /// ```
    fn stream<P>(&self, prompt: P) -> StreamBuilder<'_, Self>
    where
        Self: Sized,
        P: TryInto<Prompt>,
        Error: From<P::Error>,
    {
        let result = prompt.try_into().map_err(Error::from);
        StreamBuilder::new(self, result)
    }

    /// Executes a non-streaming generation request.
    ///
    /// This method must be implemented by providers to perform the actual API call.
    /// It receives fully configured `CallOptions` and returns a complete `GenerateResponse`
    /// containing all generated content, usage statistics, and metadata.
    ///
    /// Most users should prefer the high-level `generate()` method over calling this directly.
    async fn do_generate(&self, options: CallOptions) -> Result<GenerateResponse>;

    /// Executes a streaming generation request.
    ///
    /// This method must be implemented by providers to establish a streaming connection.
    /// It receives fully configured `CallOptions` and returns a `StreamResponse` containing
    /// an async stream of response parts.
    ///
    /// Most users should prefer the high-level `stream()` method over calling this directly.
    async fn do_stream(&self, options: CallOptions) -> Result<StreamResponse>;
}

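// Blanket implementations so that `Box<dyn LanguageModel>`, `Arc<dyn LanguageModel>`,
// and plain references can be used anywhere a `LanguageModel` is expected.
//
// Illustrative sketch (assumes a concrete `MyModel` as in the trait docs):
//
//     let model: Arc<dyn LanguageModel> = Arc::new(MyModel { api_key });
//     let response = model.generate("Hello").await?;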
#[async_trait]
impl<T: LanguageModel + ?Sized> LanguageModel for Box<T> {
    fn specification_version(&self) -> &str {
        (**self).specification_version()
    }

    fn provider(&self) -> &str {
        (**self).provider()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    async fn supported_urls(&self) -> HashMap<String, Vec<String>> {
        (**self).supported_urls().await
    }

    async fn do_generate(&self, options: CallOptions) -> Result<GenerateResponse> {
        (**self).do_generate(options).await
    }

    async fn do_stream(&self, options: CallOptions) -> Result<StreamResponse> {
        (**self).do_stream(options).await
    }
}

#[async_trait]
impl<T: LanguageModel + ?Sized> LanguageModel for Arc<T> {
    fn specification_version(&self) -> &str {
        (**self).specification_version()
    }

    fn provider(&self) -> &str {
        (**self).provider()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    async fn supported_urls(&self) -> HashMap<String, Vec<String>> {
        (**self).supported_urls().await
    }

    async fn do_generate(&self, options: CallOptions) -> Result<GenerateResponse> {
        (**self).do_generate(options).await
    }

    async fn do_stream(&self, options: CallOptions) -> Result<StreamResponse> {
        (**self).do_stream(options).await
    }
}

#[async_trait]
impl<T: LanguageModel + ?Sized> LanguageModel for &T {
    fn specification_version(&self) -> &str {
        (**self).specification_version()
    }

    fn provider(&self) -> &str {
        (**self).provider()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    async fn supported_urls(&self) -> HashMap<String, Vec<String>> {
        (**self).supported_urls().await
    }

    async fn do_generate(&self, options: CallOptions) -> Result<GenerateResponse> {
        (**self).do_generate(options).await
    }

    async fn do_stream(&self, options: CallOptions) -> Result<StreamResponse> {
        (**self).do_stream(options).await
    }
}

/// Complete response from a non-streaming generation request.
///
/// Contains all generated content, usage statistics, and provider metadata
/// from a single generation request. The response includes information about
/// why generation stopped and any warnings from the provider.
#[derive(Debug, Clone)]
pub struct GenerateResponse {
    /// Generated content parts (text, tool calls, etc.).
    pub content: Vec<Content>,

    /// Reason why generation terminated.
    ///
    /// Indicates whether generation stopped naturally, reached token limits,
    /// encountered a stop sequence, or was interrupted for other reasons.
    pub finish_reason: FinishReason,

    /// Token usage statistics for this generation.
    ///
    /// Includes prompt tokens, completion tokens, and total tokens consumed.
    /// Used for cost tracking and quota management.
    pub usage: Usage,

    /// Provider-specific metadata and additional information.
    ///
    /// May include provider-specific fields not covered by the standard response.
    pub provider_metadata: Option<SharedProviderMetadata>,

    /// Information about the request sent to the provider.
    ///
    /// Useful for debugging and auditing. May include the raw request body.
    pub request: Option<RequestInfo>,

    /// Information about the response from the provider.
    ///
    /// Contains response headers, raw body, provider-assigned IDs, and timestamps.
    pub response: Option<ResponseInfo>,

    /// Warnings issued by the provider during generation.
    ///
    /// Non-fatal issues such as unsupported parameters or deprecated features.
    /// The request proceeds despite these warnings.
    pub warnings: Vec<CallWarning>,
}

/// Response from a streaming generation request.
///
/// Contains an async stream that yields response parts as they arrive from the
/// provider. Each item in the stream represents a chunk of generated content,
/// tool calls, or metadata updates.
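///
/// # Example
///
/// A minimal sketch of inspecting the fields, assuming `model` implements
/// [`LanguageModel`].
///
/// ```rust,ignore
/// use tokio_stream::StreamExt;
///
/// let mut response = model.stream("Hi").await?;
/// if let Some(info) = &response.response {
///     println!("provider response id: {:?}", info.id);
/// }
/// while let Some(part) = response.stream.next().await {
///     // Handle `StreamPart` variants here.
/// }
/// ```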
pub struct StreamResponse {
    /// Async stream of response parts.
    ///
    /// Yields `StreamPart` items containing text deltas, tool calls, usage updates,
    /// and finish reasons. The stream completes when generation finishes.
    pub stream: Pin<Box<dyn Stream<Item = Result<StreamPart>> + Send>>,

    /// Information about the request sent to the provider.
    pub request: Option<RequestInfo>,

    /// Information about the initial response from the provider.
    ///
    /// Contains response headers and metadata available before streaming begins.
    pub response: Option<ResponseInfo>,
}

/// Information about the request sent to the provider.
///
/// Captures details about the outbound request for debugging and auditing purposes.
#[derive(Debug, Clone)]
pub struct RequestInfo {
    /// Raw request body sent to the provider's API.
    ///
    /// This is the JSON payload after conversion from `CallOptions` to the
    /// provider's specific request format.
    pub body: Option<serde_json::Value>,
}

/// Information about the response received from the provider.
///
/// Contains metadata and raw response data useful for debugging, logging,
/// and understanding provider behavior.
#[derive(Debug, Clone)]
pub struct ResponseInfo {
    /// HTTP response headers from the provider.
    ///
    /// May include rate limit information, request IDs, and other metadata.
    pub headers: Option<SharedHeaders>,

    /// Raw response body from the provider.
    ///
    /// The JSON payload before conversion to standardized response types.
    pub body: Option<serde_json::Value>,

    /// Unique identifier assigned by the provider to this response.
    ///
    /// Useful for support requests and correlating responses with provider logs.
    pub id: Option<String>,

    /// Timestamp when the provider processed this request.
    ///
    /// Format and timezone depend on the provider implementation.
    pub timestamp: Option<String>,

    /// Actual model used by the provider.
    ///
    /// May differ from the requested model if the provider performed
    /// automatic model selection or substitution.
    pub model_id: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    struct DummyModel;

    #[async_trait]
    impl LanguageModel for DummyModel {
        fn provider(&self) -> &str {
            "test"
        }

        fn model_id(&self) -> &str {
            "dummy"
        }

        async fn do_generate(&self, _opts: CallOptions) -> Result<GenerateResponse> {
            unimplemented!()
        }

        async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResponse> {
            unimplemented!()
        }
    }

    #[test]
    fn test_trait_implementation() {
        let model = DummyModel;
        assert_eq!(model.provider(), "test");
        assert_eq!(model.specification_version(), "v3");
    }
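
    // Additional sketches of builder behavior. These assume `Prompt: Default`
    // (consistent with `CallOptions::default()` above) and `Error: From<String>`
    // (consistent with the `format!(..).into()` conversions in `send()`).

    #[test]
    fn test_builder_captures_init_error() {
        let model = DummyModel;
        let builder = GenerateBuilder::new(&model, Err("bad prompt".to_string().into()));
        assert!(builder.init_error.is_some());

        let builder = StreamBuilder::new(&model, Err("bad prompt".to_string().into()));
        assert!(builder.init_error.is_some());
    }

    #[test]
    fn test_builder_sets_options() {
        let model = DummyModel;
        let builder = GenerateBuilder::new(&model, Ok(Prompt::default()))
            .temperature(0.5)
            .max_tokens(100)
            .header("x-test", "1");
        assert_eq!(builder.options.temperature, Some(0.5));
        assert_eq!(builder.options.max_output_tokens, Some(100));
        assert!(builder.options.headers.unwrap().contains_key("x-test"));
    }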
}