ai_sdk_provider/language_model/trait_def.rs
use super::*;
use crate::{Error, Result, SharedHeaders, SharedProviderMetadata};
use async_trait::async_trait;
use std::collections::HashMap;
use std::future::{Future, IntoFuture};
use std::pin::Pin;
use std::sync::Arc;
use tokio_stream::Stream;

macro_rules! impl_call_options {
    () => {
        /// Sets the maximum number of tokens to generate in the model output.
        ///
        /// This parameter controls the maximum length of the generated response. The actual
        /// number of tokens generated may be less if the model reaches a natural stopping point
        /// or encounters a stop sequence.
        pub fn max_tokens(mut self, max_tokens: u32) -> Self {
            self.options.max_output_tokens = Some(max_tokens);
            self
        }

        /// Sets the sampling temperature for controlling output randomness.
        ///
        /// Valid range: 0.0 to 2.0
        /// - Lower values (e.g., 0.2): more focused, deterministic outputs
        /// - Higher values (e.g., 1.5): more creative, diverse outputs
        /// - Default: provider-specific (typically 1.0)
        pub fn temperature(mut self, temperature: f32) -> Self {
            self.options.temperature = Some(temperature);
            self
        }

        /// Sets sequences that will cause generation to stop when encountered.
        ///
        /// When the model generates any of these sequences, generation terminates
        /// immediately. Useful for structured output or preventing unwanted continuations.
        pub fn stop_sequences(mut self, sequences: Vec<String>) -> Self {
            self.options.stop_sequences = Some(sequences);
            self
        }

        /// Sets the nucleus sampling probability mass.
        ///
        /// Valid range: 0.0 to 1.0
        /// Sampling is restricted to the smallest set of tokens whose cumulative
        /// probability reaches this threshold. Lower values produce more focused outputs
        /// by restricting the token pool to high-probability candidates.
        pub fn top_p(mut self, top_p: f32) -> Self {
            self.options.top_p = Some(top_p);
            self
        }

        /// Sets the top-k sampling parameter.
        ///
        /// Limits sampling to the k most probable next tokens. Reduces output diversity
        /// by eliminating low-probability tokens from consideration. Provider support varies.
        pub fn top_k(mut self, top_k: u32) -> Self {
            self.options.top_k = Some(top_k);
            self
        }

        /// Sets the presence penalty for reducing topic repetition.
        ///
        /// Valid range: -2.0 to 2.0
        /// Positive values penalize tokens that have already appeared in the generated text,
        /// encouraging the model to explore new topics. Negative values have the opposite effect.
        pub fn presence_penalty(mut self, penalty: f32) -> Self {
            self.options.presence_penalty = Some(penalty);
            self
        }

        /// Sets the frequency penalty for reducing token repetition.
        ///
        /// Valid range: -2.0 to 2.0
        /// Positive values penalize tokens in proportion to their frequency in the generated
        /// text, with stronger penalties for frequently occurring tokens. This reduces
        /// repetitive phrasing.
        pub fn frequency_penalty(mut self, penalty: f32) -> Self {
            self.options.frequency_penalty = Some(penalty);
            self
        }

        /// Sets a random seed for deterministic generation.
        ///
        /// Using the same seed with identical parameters should produce consistent outputs
        /// across multiple requests. Useful for reproducible testing and debugging.
        /// Provider support for deterministic generation varies.
        pub fn seed(mut self, seed: i64) -> Self {
            self.options.seed = Some(seed);
            self
        }

        /// Sets the tools available for the model to call during generation.
        ///
        /// Tools enable function calling, where the model can request execution of
        /// external functions with structured arguments. The application is responsible
        /// for executing tool calls and providing results back to the model.
        pub fn tools(mut self, tools: Vec<Tool>) -> Self {
            self.options.tools = Some(tools);
            self
        }

        /// Sets the tool selection strategy.
        ///
        /// Controls how the model decides whether and which tools to call:
        /// - `Auto`: the model autonomously decides when to use tools
        /// - `Required`: the model must call at least one tool
        /// - `None`: tools are disabled for this request
        /// - `Tool { name }`: the model must call the specified tool
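        ///
        /// A minimal sketch of forcing tool use (`weather_tool` is an illustrative
        /// placeholder, not a helper provided by this crate):
        ///
        /// ```rust,ignore
        /// let response = model.generate("What's the weather in Paris?")
        ///     .tools(vec![weather_tool])
        ///     .tool_choice(ToolChoice::Required)
        ///     .await?;
        /// ```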
        pub fn tool_choice(mut self, choice: ToolChoice) -> Self {
            self.options.tool_choice = Some(choice);
            self
        }

        /// Forces a specific output format for the generated response.
        ///
        /// Constrains the model to produce output in the specified format:
        /// - `ResponseFormat::Text`: plain text output (default)
        /// - `ResponseFormat::Json`: valid JSON object or array
        ///
        /// JSON format is useful for structured data extraction and schema-compliant outputs.
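        ///
        /// A brief sketch of requesting JSON output (assuming the prompt describes the
        /// desired shape):
        ///
        /// ```rust,ignore
        /// let response = model.generate("List three colors as a JSON array")
        ///     .response_format(ResponseFormat::Json)
        ///     .await?;
        /// ```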
        pub fn response_format(mut self, format: ResponseFormat) -> Self {
            self.options.response_format = Some(format);
            self
        }

        /// Sets custom HTTP headers for the provider request.
        ///
        /// Replaces any existing headers. Use this for provider-specific headers such as
        /// organization identifiers, API version selection, or custom authentication schemes.
        pub fn headers(mut self, headers: HashMap<String, String>) -> Self {
            self.options.headers = Some(headers);
            self
        }

        /// Adds a single HTTP header to the provider request.
        ///
        /// Merges with existing headers rather than replacing them. Useful for
        /// incrementally building header sets in a builder pattern.
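        ///
        /// A short sketch of chaining individual headers (the header names are
        /// illustrative only):
        ///
        /// ```rust,ignore
        /// let response = model.generate("Hello")
        ///     .header("X-Org-Id", "org-123")
        ///     .header("X-Request-Tag", "demo")
        ///     .await?;
        /// ```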
        pub fn header(mut self, key: impl Into<String>, value: impl Into<String>) -> Self {
            let headers = self.options.headers.get_or_insert_with(HashMap::new);
            headers.insert(key.into(), value.into());
            self
        }
    };
}

/// Builder for configuring and executing non-streaming generation requests.
///
/// This builder provides a fluent interface for setting generation parameters before
/// executing the request. It supports both explicit execution via `.send()` and
/// implicit execution by awaiting the builder directly.
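///
/// A minimal usage sketch (awaiting the builder is equivalent to calling `.send()`):
///
/// ```rust,ignore
/// // Implicit execution:
/// let response = model.generate("Summarize this article")
///     .temperature(0.3)
///     .await?;
///
/// // Explicit execution:
/// let response = model.generate("Summarize this article").send().await?;
/// ```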
#[derive(Clone)]
pub struct GenerateBuilder<'a, M: LanguageModel + ?Sized> {
    model: &'a M,
    options: CallOptions,
    init_error: Option<Arc<Error>>,
}

impl<'a, M: LanguageModel + ?Sized> GenerateBuilder<'a, M> {
    /// Creates a new generation builder with the specified prompt.
    ///
    /// If the prompt conversion fails, the error is captured and will be returned
    /// when the builder is executed.
    pub fn new(model: &'a M, prompt_result: std::result::Result<Prompt, Error>) -> Self {
        match prompt_result {
            Ok(prompt) => Self {
                model,
                options: CallOptions {
                    prompt,
                    ..Default::default()
                },
                init_error: None,
            },
            Err(e) => Self {
                model,
                options: CallOptions::default(),
                init_error: Some(Arc::new(e)),
            },
        }
    }

    impl_call_options!();

    /// Executes the generation request and returns the complete response.
    ///
    /// This method consumes the builder and performs the actual API call to the
    /// language model provider. Use this for explicit execution, or simply await
    /// the builder directly for implicit execution.
    pub async fn send(self) -> Result<GenerateResponse> {
        if let Some(err) = self.init_error {
            // The original error is behind an `Arc` and cannot be moved out or
            // cloned, so surface it as a new error wrapping its message.
            return Err(format!("Initialization error: {}", err).into());
        }
        self.model.do_generate(self.options).await
    }
}

/// Enables awaiting the builder directly without calling `.send()`.
///
/// This implementation allows the syntax `model.generate(prompt).await` as a
/// shorthand for `model.generate(prompt).send().await`.
impl<'a, M: LanguageModel + ?Sized> IntoFuture for GenerateBuilder<'a, M> {
    type Output = Result<GenerateResponse>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

    fn into_future(self) -> Self::IntoFuture {
        if let Some(err) = self.init_error {
            return Box::pin(async move { Err(format!("Initialization error: {}", err).into()) });
        }
        let model = self.model;
        let options = self.options;
        Box::pin(async move { model.do_generate(options).await })
    }
}

/// Builder for configuring and executing streaming generation requests.
///
/// This builder provides a fluent interface for setting generation parameters before
/// initiating a streaming response. Streaming allows processing partial responses as
/// they arrive, enabling real-time display and reduced time-to-first-token.
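///
/// A minimal usage sketch (awaiting the builder is equivalent to calling `.send()`):
///
/// ```rust,ignore
/// // Implicit execution:
/// let response = model.stream("Tell me a story").max_tokens(200).await?;
///
/// // Explicit execution:
/// let response = model.stream("Tell me a story").max_tokens(200).send().await?;
/// ```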
#[derive(Clone)]
pub struct StreamBuilder<'a, M: LanguageModel + ?Sized> {
    model: &'a M,
    options: CallOptions,
    init_error: Option<Arc<Error>>,
}

impl<'a, M: LanguageModel + ?Sized> StreamBuilder<'a, M> {
    /// Creates a new streaming builder with the specified prompt.
    ///
    /// If the prompt conversion fails, the error is captured and will be returned
    /// when the builder is executed.
    pub fn new(model: &'a M, prompt_result: std::result::Result<Prompt, Error>) -> Self {
        match prompt_result {
            Ok(prompt) => Self {
                model,
                options: CallOptions {
                    prompt,
                    ..Default::default()
                },
                init_error: None,
            },
            Err(e) => Self {
                model,
                options: CallOptions::default(),
                init_error: Some(Arc::new(e)),
            },
        }
    }

    impl_call_options!();

    /// Initiates the streaming request and returns a stream of response parts.
    ///
    /// This method consumes the builder and establishes the streaming connection.
    /// Use this for explicit execution, or simply await the builder directly.
    pub async fn send(self) -> Result<StreamResponse> {
        if let Some(err) = self.init_error {
            return Err(format!("Initialization error: {}", err).into());
        }
        self.model.do_stream(self.options).await
    }
}

/// Enables awaiting the builder directly without calling `.send()`.
///
/// This implementation allows the syntax `model.stream(prompt).await` as a
/// shorthand for `model.stream(prompt).send().await`.
impl<'a, M: LanguageModel + ?Sized> IntoFuture for StreamBuilder<'a, M> {
    type Output = Result<StreamResponse>;
    type IntoFuture = Pin<Box<dyn Future<Output = Self::Output> + Send + 'a>>;

    fn into_future(self) -> Self::IntoFuture {
        if let Some(err) = self.init_error {
            return Box::pin(async move { Err(format!("Initialization error: {}", err).into()) });
        }
        let model = self.model;
        let options = self.options;
        Box::pin(async move { model.do_stream(options).await })
    }
}

/// Core trait for language model providers.
///
/// This trait defines the standard interface that all language model implementations
/// must satisfy. It provides both high-level builder methods for ergonomic usage and
/// low-level `do_*` methods for direct provider integration.
///
/// # Implementation Requirements
///
/// Implementors must provide:
/// - `provider()`: provider identifier (e.g., "openai", "anthropic")
/// - `model_id()`: model identifier (e.g., "gpt-4", "claude-3-opus")
/// - `do_generate()`: non-streaming generation implementation
/// - `do_stream()`: streaming generation implementation
///
/// # Example Implementation
///
/// ```rust,ignore
/// use ai_sdk_provider::{CallOptions, GenerateResponse, LanguageModel, Result, StreamResponse};
/// use async_trait::async_trait;
///
/// struct MyModel {
///     api_key: String,
/// }
///
/// #[async_trait]
/// impl LanguageModel for MyModel {
///     fn provider(&self) -> &str { "my-provider" }
///     fn model_id(&self) -> &str { "my-model-v1" }
///
///     async fn do_generate(&self, options: CallOptions) -> Result<GenerateResponse> {
///         // Implementation
///     }
///
///     async fn do_stream(&self, options: CallOptions) -> Result<StreamResponse> {
///         // Implementation
///     }
/// }
/// ```
#[async_trait]
pub trait LanguageModel: Send + Sync {
    /// Returns the specification version implemented by this model.
    ///
    /// The default implementation returns "v3". Override this only if implementing
    /// a different specification version.
    fn specification_version(&self) -> &str {
        "v3"
    }

    /// Returns the provider identifier for this model.
    ///
    /// This should be a stable, lowercase identifier for the AI provider
    /// (e.g., "openai", "anthropic", "cohere").
    fn provider(&self) -> &str;

    /// Returns the specific model identifier.
    ///
    /// This should be the exact model name as recognized by the provider
    /// (e.g., "gpt-4-turbo", "claude-3-opus-20240229").
    fn model_id(&self) -> &str;

    /// Returns URLs supported by this model for various operations.
    ///
    /// The default implementation returns an empty map. Providers can override this
    /// to expose endpoint URLs for debugging or advanced configuration.
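    ///
    /// An illustrative override sketch (the key and URL are placeholders; the exact
    /// map semantics are provider-specific):
    ///
    /// ```rust,ignore
    /// async fn supported_urls(&self) -> HashMap<String, Vec<String>> {
    ///     let mut urls = HashMap::new();
    ///     urls.insert(
    ///         "generate".to_string(),
    ///         vec!["https://api.example.com/v1/generate".to_string()],
    ///     );
    ///     urls
    /// }
    /// ```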
    async fn supported_urls(&self) -> HashMap<String, Vec<String>> {
        HashMap::new()
    }

    /// Creates a builder for a non-streaming generation request.
    ///
    /// This high-level method provides a fluent interface for configuring and executing
    /// generation requests. The builder supports both explicit execution via `.send()`
    /// and implicit execution by awaiting the builder directly.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// let response = model.generate("Explain photosynthesis")
    ///     .temperature(0.7)
    ///     .max_tokens(500)
    ///     .await?;
    /// println!("{:?}", response.content);
    /// ```
    fn generate<P>(&self, prompt: P) -> GenerateBuilder<'_, Self>
    where
        Self: Sized,
        P: TryInto<Prompt>,
        Error: From<P::Error>,
    {
        let result = prompt.try_into().map_err(Error::from);
        GenerateBuilder::new(self, result)
    }

    /// Creates a builder for a streaming generation request.
    ///
    /// Streaming enables processing partial responses as they arrive from the provider,
    /// allowing for real-time display and reduced latency to first token. The stream
    /// yields `StreamPart` items containing text deltas, tool calls, and metadata.
    ///
    /// # Example
    ///
    /// ```rust,ignore
    /// use tokio_stream::StreamExt;
    ///
    /// let mut response = model.stream("Write a poem")
    ///     .max_tokens(500)
    ///     .await?;
    ///
    /// while let Some(part) = response.stream.next().await {
    ///     match part? {
    ///         StreamPart::TextDelta(delta) => print!("{}", delta),
    ///         _ => {}
    ///     }
    /// }
    /// ```
    fn stream<P>(&self, prompt: P) -> StreamBuilder<'_, Self>
    where
        Self: Sized,
        P: TryInto<Prompt>,
        Error: From<P::Error>,
    {
        let result = prompt.try_into().map_err(Error::from);
        StreamBuilder::new(self, result)
    }

    /// Executes a non-streaming generation request.
    ///
    /// This method must be implemented by providers to perform the actual API call.
    /// It receives fully configured `CallOptions` and returns a complete `GenerateResponse`
    /// containing all generated content, usage statistics, and metadata.
    ///
    /// Most users should prefer the high-level `generate()` method over calling this directly.
    async fn do_generate(&self, options: CallOptions) -> Result<GenerateResponse>;

    /// Executes a streaming generation request.
    ///
    /// This method must be implemented by providers to establish a streaming connection.
    /// It receives fully configured `CallOptions` and returns a `StreamResponse` containing
    /// an async stream of response parts.
    ///
    /// Most users should prefer the high-level `stream()` method over calling this directly.
    async fn do_stream(&self, options: CallOptions) -> Result<StreamResponse>;
}

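// Blanket implementations that forward every method to the wrapped model. These
// let `Box<dyn LanguageModel>`, `Arc<dyn LanguageModel>`, and plain references be
// used anywhere a `LanguageModel` is expected. The `generate`/`stream` builder
// methods keep their default bodies, since they only require `Self: Sized`.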
#[async_trait]
impl<T: LanguageModel + ?Sized> LanguageModel for Box<T> {
    fn specification_version(&self) -> &str {
        (**self).specification_version()
    }

    fn provider(&self) -> &str {
        (**self).provider()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    async fn supported_urls(&self) -> HashMap<String, Vec<String>> {
        (**self).supported_urls().await
    }

    async fn do_generate(&self, options: CallOptions) -> Result<GenerateResponse> {
        (**self).do_generate(options).await
    }

    async fn do_stream(&self, options: CallOptions) -> Result<StreamResponse> {
        (**self).do_stream(options).await
    }
}

#[async_trait]
impl<T: LanguageModel + ?Sized> LanguageModel for Arc<T> {
    fn specification_version(&self) -> &str {
        (**self).specification_version()
    }

    fn provider(&self) -> &str {
        (**self).provider()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    async fn supported_urls(&self) -> HashMap<String, Vec<String>> {
        (**self).supported_urls().await
    }

    async fn do_generate(&self, options: CallOptions) -> Result<GenerateResponse> {
        (**self).do_generate(options).await
    }

    async fn do_stream(&self, options: CallOptions) -> Result<StreamResponse> {
        (**self).do_stream(options).await
    }
}

#[async_trait]
impl<T: LanguageModel + ?Sized> LanguageModel for &T {
    fn specification_version(&self) -> &str {
        (**self).specification_version()
    }

    fn provider(&self) -> &str {
        (**self).provider()
    }

    fn model_id(&self) -> &str {
        (**self).model_id()
    }

    async fn supported_urls(&self) -> HashMap<String, Vec<String>> {
        (**self).supported_urls().await
    }

    async fn do_generate(&self, options: CallOptions) -> Result<GenerateResponse> {
        (**self).do_generate(options).await
    }

    async fn do_stream(&self, options: CallOptions) -> Result<StreamResponse> {
        (**self).do_stream(options).await
    }
}

/// Complete response from a non-streaming generation request.
///
/// Contains all generated content, usage statistics, and provider metadata
/// from a single generation request. The response includes information about
/// why generation stopped and any warnings from the provider.
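///
/// A small sketch of inspecting a response (field types are as declared below):
///
/// ```rust,ignore
/// let response = model.generate("Hello").await?;
/// println!("finish reason: {:?}", response.finish_reason);
/// println!("usage: {:?}", response.usage);
/// for warning in &response.warnings {
///     eprintln!("provider warning: {:?}", warning);
/// }
/// ```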
#[derive(Debug, Clone)]
pub struct GenerateResponse {
    /// Generated content parts (text, tool calls, etc.).
    pub content: Vec<Content>,

    /// Reason why generation terminated.
    ///
    /// Indicates whether generation stopped naturally, reached token limits,
    /// encountered a stop sequence, or was interrupted for other reasons.
    pub finish_reason: FinishReason,

    /// Token usage statistics for this generation.
    ///
    /// Includes prompt tokens, completion tokens, and total tokens consumed.
    /// Used for cost tracking and quota management.
    pub usage: Usage,

    /// Provider-specific metadata and additional information.
    ///
    /// May include provider-specific fields not covered by the standard response.
    pub provider_metadata: Option<SharedProviderMetadata>,

    /// Information about the request sent to the provider.
    ///
    /// Useful for debugging and auditing. May include the raw request body.
    pub request: Option<RequestInfo>,

    /// Information about the response from the provider.
    ///
    /// Contains response headers, raw body, provider-assigned IDs, and timestamps.
    pub response: Option<ResponseInfo>,

    /// Warnings issued by the provider during generation.
    ///
    /// Non-fatal issues such as unsupported parameters or deprecated features.
    /// The request proceeds despite these warnings.
    pub warnings: Vec<CallWarning>,
}

/// Response from a streaming generation request.
///
/// Contains an async stream that yields response parts as they arrive from the
/// provider. Each item in the stream represents a chunk of generated content,
/// tool calls, or metadata updates.
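///
/// A minimal sketch of draining the stream (requires `tokio_stream::StreamExt`):
///
/// ```rust,ignore
/// use tokio_stream::StreamExt;
///
/// let mut response = model.stream("Hello").await?;
/// while let Some(part) = response.stream.next().await {
///     let part = part?;
///     // Process text deltas, tool calls, usage updates, etc.
/// }
/// ```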
pub struct StreamResponse {
    /// Async stream of response parts.
    ///
    /// Yields `StreamPart` items containing text deltas, tool calls, usage updates,
    /// and finish reasons. The stream completes when generation finishes.
    pub stream: Pin<Box<dyn Stream<Item = Result<StreamPart>> + Send>>,

    /// Information about the request sent to the provider.
    pub request: Option<RequestInfo>,

    /// Information about the initial response from the provider.
    ///
    /// Contains response headers and metadata available before streaming begins.
    pub response: Option<ResponseInfo>,
}

/// Information about the request sent to the provider.
///
/// Captures details about the outbound request for debugging and auditing purposes.
#[derive(Debug, Clone)]
pub struct RequestInfo {
    /// Raw request body sent to the provider's API.
    ///
    /// This is the JSON payload after conversion from `CallOptions` to the
    /// provider's specific request format.
    pub body: Option<serde_json::Value>,
}

/// Information about the response received from the provider.
///
/// Contains metadata and raw response data useful for debugging, logging,
/// and understanding provider behavior.
#[derive(Debug, Clone)]
pub struct ResponseInfo {
    /// HTTP response headers from the provider.
    ///
    /// May include rate limit information, request IDs, and other metadata.
    pub headers: Option<SharedHeaders>,

    /// Raw response body from the provider.
    ///
    /// The JSON payload before conversion to standardized response types.
    pub body: Option<serde_json::Value>,

    /// Unique identifier assigned by the provider to this response.
    ///
    /// Useful for support requests and correlating responses with provider logs.
    pub id: Option<String>,

    /// Timestamp when the provider processed this request.
    ///
    /// Format and timezone depend on the provider implementation.
    pub timestamp: Option<String>,

    /// Actual model used by the provider.
    ///
    /// May differ from the requested model if the provider performed
    /// automatic model selection or substitution.
    pub model_id: Option<String>,
}

#[cfg(test)]
mod tests {
    use super::*;

    struct DummyModel;

    #[async_trait]
    impl LanguageModel for DummyModel {
        fn provider(&self) -> &str {
            "test"
        }

        fn model_id(&self) -> &str {
            "dummy"
        }

        async fn do_generate(&self, _opts: CallOptions) -> Result<GenerateResponse> {
            unimplemented!()
        }

        async fn do_stream(&self, _opts: CallOptions) -> Result<StreamResponse> {
            unimplemented!()
        }
    }

    #[test]
    fn test_trait_implementation() {
        let model = DummyModel;
        assert_eq!(model.provider(), "test");
        assert_eq!(model.specification_version(), "v3");
    }
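
    // A small additional check, grounded in the blanket impls above: trait objects
    // behind `Box` and `Arc` should delegate to the inner model.
    #[test]
    fn test_blanket_impls_delegate() {
        let boxed: Box<dyn LanguageModel> = Box::new(DummyModel);
        assert_eq!(boxed.provider(), "test");
        assert_eq!(boxed.model_id(), "dummy");

        let shared: Arc<dyn LanguageModel> = Arc::new(DummyModel);
        assert_eq!(shared.specification_version(), "v3");
    }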
}