// llm/providers/openai/provider.rs

1use async_openai::{Client, config::Config, types::chat::CreateChatCompletionRequest};
2use async_stream;
3use std::error::Error;
4use tokio_stream::StreamExt;
5use tracing::{debug, error};
6
7use super::{
8    mappers::{map_messages, map_tools},
9    streaming::process_completion_stream,
10};
11use crate::{Context, LlmError, LlmResponseStream, StreamingModelProvider};
12
/// A Provider that's compatible with `OpenAI`'s chat completion API
/// Other providers (e.g. Ollama, Llama.cpp etc) that are "`OpenAI` compatible" should implement this trait
pub trait OpenAiChatProvider {
    /// The `async_openai` client configuration type for this provider
    /// (e.g. base URL / API key for an OpenAI-compatible server).
    /// `Clone + 'static` is required so the client can be captured by the
    /// `'static` response stream in the blanket `StreamingModelProvider` impl.
    type Config: Config + Clone + 'static;

    /// The configured `async_openai` client used to issue chat requests.
    fn client(&self) -> &Client<Self::Config>;
    /// The model identifier sent with each chat completion request.
    fn model(&self) -> &str;
    /// Human-readable provider name (used for display and logging).
    fn provider_name(&self) -> &str;
}
22
23impl<T: OpenAiChatProvider + Send + Sync> StreamingModelProvider for T {
24    fn stream_response(&self, context: &Context) -> LlmResponseStream {
25        let client = self.client().clone();
26        let model = self.model().to_string();
27        let prompt_cache_key = context.prompt_cache_key().map(String::from);
28        let messages = match map_messages(context.messages()) {
29            Ok(messages) => messages,
30            Err(e) => return Box::pin(async_stream::stream! { yield Err(e); }),
31        };
32        let message_count = messages.len();
33        let tools = if context.tools().is_empty() {
34            None
35        } else {
36            match map_tools(context.tools()) {
37                Ok(t) => Some(t),
38                Err(e) => return Box::pin(async_stream::stream! { yield Err(e); }),
39            }
40        };
41
42        Box::pin(async_stream::stream! {
43            debug!("Starting chat completion stream for model: {model}");
44
45            let req = CreateChatCompletionRequest {
46                model: model.clone(),
47                messages,
48                tools,
49                stream: Some(true),
50                prompt_cache_key,
51                ..Default::default()
52            };
53
54            debug!(
55                "Making request to Ollama API with model: {model} and {message_count} messages"
56            );
57
58            let stream = match client.chat().create_stream(req).await {
59                Ok(stream) => {
60                    debug!("Successfully created stream from Ollama API");
61                    stream
62                }
63                Err(e) => {
64                    error!("Failed to create stream from Ollama API: {:?}", e);
65
66                    // Check if it's a reqwest error with more details
67                    if let Some(reqwest_err) =
68                        e.source().and_then(|s| s.downcast_ref::<reqwest::Error>())
69                    {
70                        if let Some(url) = reqwest_err.url() {
71                            error!("Request URL was: {url}");
72                        }
73                        if let Some(status) = reqwest_err.status() {
74                            error!("HTTP status: {status}");
75                        }
76                    }
77
78                    yield Err(LlmError::ApiRequest(e.to_string()));
79                    return;
80                }
81            };
82
83            let stream = stream.map(|result| {
84                result.map_err(|e| LlmError::ApiError(e.to_string()))
85            });
86
87            let mut shared_stream = Box::pin(process_completion_stream(stream));
88            while let Some(result) = shared_stream.next().await {
89                yield result;
90            }
91        })
92    }
93
94    fn context_window(&self) -> Option<u32> {
95        None
96    }
97
98    fn display_name(&self) -> String {
99        let model = self.model();
100        if model.is_empty() { self.provider_name().to_string() } else { format!("{} ({model})", self.provider_name()) }
101    }
102}