1use super::AiClient;
10use crate::api::ChatCompletionChunk;
11use crate::rate_limiter::BackpressurePermit;
12use crate::types::{AiLibError, ChatCompletionRequest};
13use futures::stream::Stream;
14use futures::Future;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use tokio::sync::oneshot;
18use tracing::warn;
19
/// Handle that lets a consumer request cancellation of an in-flight
/// streaming chat completion (see `chat_completion_stream_with_cancel`).
pub struct CancelHandle {
    // One-shot cancellation channel; `Some` until `cancel` consumes it.
    sender: Option<oneshot::Sender<()>>,
}
24
25impl CancelHandle {
26 pub fn cancel(mut self) {
28 if let Some(sender) = self.sender.take() {
29 let _ = sender.send(());
30 }
31 }
32}
33
/// Stream adapter that layers cancellation and backpressure bookkeeping
/// on top of a provider-supplied chunk stream.
struct ControlledStream {
    // Underlying provider stream of chat-completion chunks.
    inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
    // Cancellation signal; `None` when the stream is not cancellable or
    // the signal has already been consumed.
    cancel_rx: Option<oneshot::Receiver<()>>,
    // Held only so the backpressure permit is released when the stream drops.
    _bp_permit: Option<BackpressurePermit>,
}
41
42impl ControlledStream {
43 fn new_with_bp(
44 inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
45 cancel_rx: Option<oneshot::Receiver<()>>,
46 bp_permit: Option<BackpressurePermit>,
47 ) -> Self {
48 Self {
49 inner,
50 cancel_rx,
51 _bp_permit: bp_permit,
52 }
53 }
54}
55
56impl Stream for ControlledStream {
57 type Item = Result<ChatCompletionChunk, AiLibError>;
58
59 fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
60 use futures::stream::StreamExt;
61
62 if let Some(ref mut cancel_rx) = self.cancel_rx {
64 match Future::poll(Pin::new(cancel_rx), cx) {
65 Poll::Ready(_) => {
66 self.cancel_rx = None;
67 return Poll::Ready(Some(Err(AiLibError::ProviderError(
68 "Stream cancelled".to_string(),
69 ))));
70 }
71 Poll::Pending => {}
72 }
73 }
74
75 self.inner.poll_next_unpin(cx)
77 }
78}
79
80#[allow(unused_mut)]
81pub async fn chat_completion_stream(
82 client: &AiClient,
83 request: ChatCompletionRequest,
84) -> Result<
85 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
86 AiLibError,
87> {
88 let provider = client.provider_id();
89 let mut processed_request = client.prepare_chat_request(request);
90 processed_request.stream = Some(true);
91 let bp_permit: Option<BackpressurePermit> = if let Some(ctrl) = &client.backpressure {
93 match ctrl.acquire_permit().await {
94 Ok(p) => Some(p),
95 Err(_) => {
96 return Err(AiLibError::RateLimitExceeded(
97 "Backpressure: no permits available".to_string(),
98 ))
99 }
100 }
101 } else {
102 None
103 };
104 loop {
105 match client.chat_provider.stream(processed_request.clone()).await {
106 Ok(inner) => {
107 let cs = ControlledStream::new_with_bp(inner, None, bp_permit);
108 return Ok(Box::new(cs));
109 }
110 Err(err) => {
111 if client.model_resolver.looks_like_invalid_model(&err) {
112 if let Some(resolution) =
113 client.fallback_model_after_invalid(&processed_request.model)
114 {
115 warn!(
116 target = "ai_lib.model",
117 provider = ?provider,
118 failed_model = %processed_request.model,
119 fallback_model = %resolution.model,
120 source = ?resolution.source,
121 "Retrying stream with fallback model"
122 );
123 processed_request.model = resolution.model;
124 continue;
125 }
126 let decorated = client.model_resolver.decorate_invalid_model_error(
127 provider,
128 &processed_request.model,
129 err,
130 );
131 return Err(decorated.with_context("AiClient::chat_completion_stream"));
132 }
133
134 return Err(err.with_context("AiClient::chat_completion_stream"));
135 }
136 }
137 }
138}
139
140pub async fn chat_completion_stream_with_cancel(
141 client: &AiClient,
142 request: ChatCompletionRequest,
143) -> Result<
144 (
145 Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
146 CancelHandle,
147 ),
148 AiLibError,
149> {
150 let provider = client.provider_id();
151 let mut processed_request = client.prepare_chat_request(request);
152 processed_request.stream = Some(true);
153 let bp_permit: Option<BackpressurePermit> = if let Some(ctrl) = &client.backpressure {
155 match ctrl.acquire_permit().await {
156 Ok(p) => Some(p),
157 Err(_) => {
158 return Err(AiLibError::RateLimitExceeded(
159 "Backpressure: no permits available".to_string(),
160 ))
161 }
162 }
163 } else {
164 None
165 };
166 let stream = loop {
167 match client.chat_provider.stream(processed_request.clone()).await {
168 Ok(stream) => break stream,
169 Err(err) => {
170 if client.model_resolver.looks_like_invalid_model(&err) {
171 if let Some(resolution) =
172 client.fallback_model_after_invalid(&processed_request.model)
173 {
174 warn!(
175 target = "ai_lib.model",
176 provider = ?provider,
177 failed_model = %processed_request.model,
178 fallback_model = %resolution.model,
179 source = ?resolution.source,
180 "Retrying cancellable stream with fallback model"
181 );
182 processed_request.model = resolution.model;
183 continue;
184 }
185 let decorated = client.model_resolver.decorate_invalid_model_error(
186 provider,
187 &processed_request.model,
188 err,
189 );
190 return Err(
191 decorated.with_context("AiClient::chat_completion_stream_with_cancel")
192 );
193 }
194 return Err(err.with_context("AiClient::chat_completion_stream_with_cancel"));
195 }
196 }
197 };
198 let (cancel_tx, cancel_rx) = oneshot::channel();
199 let cancel_handle = CancelHandle {
200 sender: Some(cancel_tx),
201 };
202
203 let controlled_stream = ControlledStream::new_with_bp(stream, Some(cancel_rx), bp_permit);
204 Ok((Box::new(controlled_stream), cancel_handle))
205}