ai_lib/client/
stream.rs

1//! Streaming request execution module.
2//!
3//! This module handles streaming chat completion requests.
4//! It supports:
5//! - Standard streaming
6//! - Cancellable streaming via `CancelHandle`
7//! - Backpressure control
8
9use super::AiClient;
10use crate::api::ChatCompletionChunk;
11use crate::rate_limiter::BackpressurePermit;
12use crate::types::{AiLibError, ChatCompletionRequest};
13use futures::stream::Stream;
14use futures::Future;
15use std::pin::Pin;
16use std::task::{Context, Poll};
17use tokio::sync::oneshot;
18use tracing::warn;
19
20/// Streaming response cancel handle
21pub struct CancelHandle {
22    sender: Option<oneshot::Sender<()>>,
23}
24
25impl CancelHandle {
26    /// Cancel streaming response
27    pub fn cancel(mut self) {
28        if let Some(sender) = self.sender.take() {
29            let _ = sender.send(());
30        }
31    }
32}
33
/// Controllable streaming response.
///
/// Wraps a provider chunk stream and layers on:
/// - optional cancellation via a oneshot receiver (paired with [`CancelHandle`])
/// - optional backpressure: the permit is held until this wrapper is dropped
struct ControlledStream {
    inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
    // Receiver half of the cancel channel; `None` when the stream is not cancellable.
    cancel_rx: Option<oneshot::Receiver<()>>,
    // Hold a backpressure permit for the lifetime of the stream if present
    _bp_permit: Option<BackpressurePermit>,
}
41
42impl ControlledStream {
43    fn new_with_bp(
44        inner: Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
45        cancel_rx: Option<oneshot::Receiver<()>>,
46        bp_permit: Option<BackpressurePermit>,
47    ) -> Self {
48        Self {
49            inner,
50            cancel_rx,
51            _bp_permit: bp_permit,
52        }
53    }
54}
55
56impl Stream for ControlledStream {
57    type Item = Result<ChatCompletionChunk, AiLibError>;
58
59    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
60        use futures::stream::StreamExt;
61
62        // Check if cancelled
63        if let Some(ref mut cancel_rx) = self.cancel_rx {
64            match Future::poll(Pin::new(cancel_rx), cx) {
65                Poll::Ready(_) => {
66                    self.cancel_rx = None;
67                    return Poll::Ready(Some(Err(AiLibError::ProviderError(
68                        "Stream cancelled".to_string(),
69                    ))));
70                }
71                Poll::Pending => {}
72            }
73        }
74
75        // Poll inner stream
76        self.inner.poll_next_unpin(cx)
77    }
78}
79
80#[allow(unused_mut)]
81pub async fn chat_completion_stream(
82    client: &AiClient,
83    request: ChatCompletionRequest,
84) -> Result<
85    Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
86    AiLibError,
87> {
88    let provider = client.provider_id();
89    let mut processed_request = client.prepare_chat_request(request);
90    processed_request.stream = Some(true);
91    // Acquire backpressure permit if configured and hold it for the lifetime of the stream
92    let bp_permit: Option<BackpressurePermit> = if let Some(ctrl) = &client.backpressure {
93        match ctrl.acquire_permit().await {
94            Ok(p) => Some(p),
95            Err(_) => {
96                return Err(AiLibError::RateLimitExceeded(
97                    "Backpressure: no permits available".to_string(),
98                ))
99            }
100        }
101    } else {
102        None
103    };
104    loop {
105        match client.chat_provider.stream(processed_request.clone()).await {
106            Ok(inner) => {
107                let cs = ControlledStream::new_with_bp(inner, None, bp_permit);
108                return Ok(Box::new(cs));
109            }
110            Err(err) => {
111                if client.model_resolver.looks_like_invalid_model(&err) {
112                    if let Some(resolution) =
113                        client.fallback_model_after_invalid(&processed_request.model)
114                    {
115                        warn!(
116                            target = "ai_lib.model",
117                            provider = ?provider,
118                            failed_model = %processed_request.model,
119                            fallback_model = %resolution.model,
120                            source = ?resolution.source,
121                            "Retrying stream with fallback model"
122                        );
123                        processed_request.model = resolution.model;
124                        continue;
125                    }
126                    let decorated = client.model_resolver.decorate_invalid_model_error(
127                        provider,
128                        &processed_request.model,
129                        err,
130                    );
131                    return Err(decorated.with_context("AiClient::chat_completion_stream"));
132                }
133
134                return Err(err.with_context("AiClient::chat_completion_stream"));
135            }
136        }
137    }
138}
139
140pub async fn chat_completion_stream_with_cancel(
141    client: &AiClient,
142    request: ChatCompletionRequest,
143) -> Result<
144    (
145        Box<dyn Stream<Item = Result<ChatCompletionChunk, AiLibError>> + Send + Unpin>,
146        CancelHandle,
147    ),
148    AiLibError,
149> {
150    let provider = client.provider_id();
151    let mut processed_request = client.prepare_chat_request(request);
152    processed_request.stream = Some(true);
153    // Acquire backpressure permit if configured and hold it for the lifetime of the stream
154    let bp_permit: Option<BackpressurePermit> = if let Some(ctrl) = &client.backpressure {
155        match ctrl.acquire_permit().await {
156            Ok(p) => Some(p),
157            Err(_) => {
158                return Err(AiLibError::RateLimitExceeded(
159                    "Backpressure: no permits available".to_string(),
160                ))
161            }
162        }
163    } else {
164        None
165    };
166    let stream = loop {
167        match client.chat_provider.stream(processed_request.clone()).await {
168            Ok(stream) => break stream,
169            Err(err) => {
170                if client.model_resolver.looks_like_invalid_model(&err) {
171                    if let Some(resolution) =
172                        client.fallback_model_after_invalid(&processed_request.model)
173                    {
174                        warn!(
175                            target = "ai_lib.model",
176                            provider = ?provider,
177                            failed_model = %processed_request.model,
178                            fallback_model = %resolution.model,
179                            source = ?resolution.source,
180                            "Retrying cancellable stream with fallback model"
181                        );
182                        processed_request.model = resolution.model;
183                        continue;
184                    }
185                    let decorated = client.model_resolver.decorate_invalid_model_error(
186                        provider,
187                        &processed_request.model,
188                        err,
189                    );
190                    return Err(
191                        decorated.with_context("AiClient::chat_completion_stream_with_cancel")
192                    );
193                }
194                return Err(err.with_context("AiClient::chat_completion_stream_with_cancel"));
195            }
196        }
197    };
198    let (cancel_tx, cancel_rx) = oneshot::channel();
199    let cancel_handle = CancelHandle {
200        sender: Some(cancel_tx),
201    };
202
203    let controlled_stream = ControlledStream::new_with_bp(stream, Some(cancel_rx), bp_permit);
204    Ok((Box::new(controlled_stream), cancel_handle))
205}