llm/providers/openai/
provider.rs1use async_openai::{Client, config::Config, types::chat::CreateChatCompletionRequest};
2use async_stream;
3use std::error::Error;
4use tokio_stream::StreamExt;
5use tracing::{debug, error};
6
7use super::{
8 mappers::{map_messages, map_tools},
9 streaming::process_completion_stream,
10};
11use crate::{Context, LlmError, LlmResponseStream, StreamingModelProvider};
12
13pub trait OpenAiChatProvider {
16 type Config: Config + Clone + 'static;
17
18 fn client(&self) -> &Client<Self::Config>;
19 fn model(&self) -> &str;
20 fn provider_name(&self) -> &str;
21}
22
23impl<T: OpenAiChatProvider + Send + Sync> StreamingModelProvider for T {
24 fn stream_response(&self, context: &Context) -> LlmResponseStream {
25 let client = self.client().clone();
26 let model = self.model().to_string();
27 let prompt_cache_key = context.prompt_cache_key().map(String::from);
28 let messages = match map_messages(context.messages()) {
29 Ok(messages) => messages,
30 Err(e) => return Box::pin(async_stream::stream! { yield Err(e); }),
31 };
32 let message_count = messages.len();
33 let tools = if context.tools().is_empty() {
34 None
35 } else {
36 match map_tools(context.tools()) {
37 Ok(t) => Some(t),
38 Err(e) => return Box::pin(async_stream::stream! { yield Err(e); }),
39 }
40 };
41
42 Box::pin(async_stream::stream! {
43 debug!("Starting chat completion stream for model: {model}");
44
45 let req = CreateChatCompletionRequest {
46 model: model.clone(),
47 messages,
48 tools,
49 stream: Some(true),
50 prompt_cache_key,
51 ..Default::default()
52 };
53
54 debug!(
55 "Making request to Ollama API with model: {model} and {message_count} messages"
56 );
57
58 let stream = match client.chat().create_stream(req).await {
59 Ok(stream) => {
60 debug!("Successfully created stream from Ollama API");
61 stream
62 }
63 Err(e) => {
64 error!("Failed to create stream from Ollama API: {:?}", e);
65
66 if let Some(reqwest_err) =
68 e.source().and_then(|s| s.downcast_ref::<reqwest::Error>())
69 {
70 if let Some(url) = reqwest_err.url() {
71 error!("Request URL was: {url}");
72 }
73 if let Some(status) = reqwest_err.status() {
74 error!("HTTP status: {status}");
75 }
76 }
77
78 yield Err(LlmError::ApiRequest(e.to_string()));
79 return;
80 }
81 };
82
83 let stream = stream.map(|result| {
84 result.map_err(|e| LlmError::ApiError(e.to_string()))
85 });
86
87 let mut shared_stream = Box::pin(process_completion_stream(stream));
88 while let Some(result) = shared_stream.next().await {
89 yield result;
90 }
91 })
92 }
93
94 fn context_window(&self) -> Option<u32> {
95 None
96 }
97
98 fn display_name(&self) -> String {
99 let model = self.model();
100 if model.is_empty() { self.provider_name().to_string() } else { format!("{} ({model})", self.provider_name()) }
101 }
102}