1use crate::LlmModel;
2use crate::ProviderConnectionConfig;
3use crate::Result as LlmResult;
4use std::future::Future;
5use std::pin::Pin;
6use tokio_stream::Stream;
7
8use super::{Context, LlmResponse};
9
10pub type LlmResponseStream = Pin<Box<dyn Stream<Item = LlmResult<LlmResponse>> + Send>>;
17
18#[doc = include_str!("docs/provider_factory.md")]
19pub trait ProviderFactory: Sized {
20 fn from_env() -> impl Future<Output = LlmResult<Self>> + Send;
22
23 fn from_env_with_connection(connection: ProviderConnectionConfig) -> impl Future<Output = LlmResult<Self>> + Send {
25 async move {
26 let _ = connection;
27 Self::from_env().await
28 }
29 }
30
31 fn with_model(self, model: &str) -> Self;
33}
34
35#[doc = include_str!("docs/streaming_model_provider.md")]
36pub trait StreamingModelProvider: Send + Sync {
37 fn stream_response(&self, context: &Context) -> LlmResponseStream;
38 fn display_name(&self) -> String;
39
40 fn context_window(&self) -> Option<u32>;
43
44 fn model(&self) -> Option<LlmModel> {
48 None
49 }
50}
51
52pub fn get_context_window(provider: &str, model_id: &str) -> Option<u32> {
56 let key = format!("{provider}:{model_id}");
57 key.parse::<LlmModel>().ok().and_then(|m| m.context_window())
58}
59
60impl StreamingModelProvider for Box<dyn StreamingModelProvider> {
61 fn stream_response(&self, context: &Context) -> LlmResponseStream {
62 (**self).stream_response(context)
63 }
64
65 fn display_name(&self) -> String {
66 (**self).display_name()
67 }
68
69 fn context_window(&self) -> Option<u32> {
70 (**self).context_window()
71 }
72
73 fn model(&self) -> Option<LlmModel> {
74 (**self).model()
75 }
76}
77
78impl<T: StreamingModelProvider + ?Sized> StreamingModelProvider for std::sync::Arc<T> {
79 fn stream_response(&self, context: &Context) -> LlmResponseStream {
80 (**self).stream_response(context)
81 }
82
83 fn display_name(&self) -> String {
84 (**self).display_name()
85 }
86
87 fn context_window(&self) -> Option<u32> {
88 (**self).context_window()
89 }
90
91 fn model(&self) -> Option<LlmModel> {
92 (**self).model()
93 }
94}
95
#[cfg(test)]
mod tests {
    // Unit tests for `get_context_window` lookups against the real model table.
    use super::*;

    #[test]
    fn lookup_context_window_known_model() {
        let window = get_context_window("anthropic", "claude-opus-4-6");
        assert_eq!(window, Some(1_000_000));
    }

    #[test]
    fn lookup_context_window_openrouter_model() {
        // This model id itself contains a `/`-separated prefix, so the key that
        // gets parsed is `openrouter:anthropic/claude-opus-4`.
        assert_eq!(
            get_context_window("openrouter", "anthropic/claude-opus-4"),
            Some(200_000)
        );
    }

    #[test]
    fn lookup_context_window_unknown_model() {
        let window = get_context_window("anthropic", "unknown-model-xyz");
        assert_eq!(window, None);
    }

    #[test]
    fn lookup_context_window_unknown_provider() {
        let window = get_context_window("unknown-provider", "some-model");
        assert_eq!(window, None);
    }
}