// heartbit_core/llm/mod.rs
1//! LLM provider abstractions — `LlmProvider` trait, Anthropic/Gemini/OpenRouter/OpenAI-compat backends, retry, cascade, and circuit-breaker wrappers.
2
3pub mod anthropic;
4pub mod circuit;
5
/// Maximum bytes read from an upstream LLM error body.
///
/// SECURITY (F-LLM-5, F-LLM-6): a hostile or compromised provider can stream
/// gigabytes of body in response to a 4xx/5xx and OOM the agent. We also
/// truncate to keep accidentally-included secrets / internal IPs from
/// flooding logs.
///
/// Enforced by [`api_error_from_response`].
const ERROR_BODY_MAX_BYTES: usize = 8 << 10; // 8 KiB

/// SECURITY (F-LLM-4): hard cap on accumulated streaming text bytes per
/// response. A drip-fed `text_delta` event sequence would otherwise grow
/// unbounded.
pub(crate) const STREAM_MAX_TEXT_BYTES: usize = 16 << 20; // 16 MiB

/// SECURITY (F-LLM-4): hard cap on accumulated tool-call arguments JSON per
/// individual tool call.
pub(crate) const STREAM_MAX_TOOL_ARGS_BYTES: usize = 1 << 20; // 1 MiB

/// SECURITY (F-LLM-4): hard cap on the number of tool calls a single
/// streaming response may emit. Protects against a hostile `tool_calls[].index`
/// of `u32::MAX` triggering a multi-billion-entry Vec allocation.
pub(crate) const STREAM_MAX_TOOL_CALLS: usize = 256;
27
28/// Build an `Error::Api` from a failed HTTP response, sanitizing auth errors.
29///
30/// For 401/403 responses, the body is NOT read to avoid leaking API key
31/// fragments in logs. For all other statuses the response body is included
32/// but capped at `ERROR_BODY_MAX_BYTES` and stripped of control characters
33/// (newlines, ANSI escapes) so a hostile body cannot poison structured logs
34/// (F-LLM-6).
35pub(crate) async fn api_error_from_response(response: reqwest::Response) -> Error {
36    use futures::TryStreamExt;
37    let status = response.status().as_u16();
38    let message = if status == 401 || status == 403 {
39        format!("authentication failed (HTTP {status})")
40    } else {
41        let mut buf: Vec<u8> = Vec::with_capacity(2048);
42        let mut stream = response.bytes_stream();
43        let mut overflowed = false;
44        loop {
45            match stream.try_next().await {
46                Ok(Some(chunk)) => {
47                    let remaining = ERROR_BODY_MAX_BYTES.saturating_sub(buf.len());
48                    if remaining == 0 {
49                        overflowed = true;
50                        break;
51                    }
52                    let take = chunk.len().min(remaining);
53                    buf.extend_from_slice(&chunk[..take]);
54                    if take < chunk.len() {
55                        overflowed = true;
56                        break;
57                    }
58                }
59                Ok(None) => break,
60                Err(e) => {
61                    return Error::Api {
62                        status,
63                        message: format!("<body read error: {e}>"),
64                    };
65                }
66            }
67        }
68        let mut text = String::from_utf8_lossy(&buf).to_string();
69        // Strip control characters (CR/LF/ANSI ESC). Keep tabs and printable.
70        text.retain(|c| c == '\t' || (!c.is_control() && c != '\u{1b}'));
71        if overflowed {
72            text.push_str("…[truncated]");
73        }
74        text
75    };
76    Error::Api { status, message }
77}
78pub mod cascade;
79pub mod error_class;
80pub mod gemini;
81pub mod openai_compat;
82pub mod openrouter;
83pub mod pricing;
84pub mod registry;
85pub mod retry;
86pub mod types;
87
88use std::future::Future;
89use std::pin::Pin;
90use std::sync::Arc;
91
92use crate::error::Error;
93use crate::llm::types::{CompletionRequest, CompletionResponse};
94
/// Callback invoked with each text delta during streaming.
///
/// `Send + Sync` so a single callback can be shared across async tasks.
pub type OnText = dyn Fn(&str) + Send + Sync;
97
/// Decision returned by the `OnApproval` callback.
///
/// `Allow` and `Deny` mirror the old boolean return values. The `Always*`
/// variants additionally persist the choice as a learned permission rule so
/// it survives across sessions.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ApprovalDecision {
    /// Permit this single execution.
    Allow,
    /// Refuse this single execution.
    Deny,
    /// Permit and remember the choice as a permission rule.
    AlwaysAllow,
    /// Refuse and remember the choice as a permission rule.
    AlwaysDeny,
}

impl ApprovalDecision {
    /// `true` when execution may proceed.
    pub fn is_allowed(self) -> bool {
        match self {
            Self::Allow | Self::AlwaysAllow => true,
            Self::Deny | Self::AlwaysDeny => false,
        }
    }

    /// `true` when the decision should be stored as a permission rule.
    pub fn is_persistent(self) -> bool {
        match self {
            Self::AlwaysAllow | Self::AlwaysDeny => true,
            Self::Allow | Self::Deny => false,
        }
    }
}

impl From<bool> for ApprovalDecision {
    fn from(allowed: bool) -> Self {
        match allowed {
            true => Self::Allow,
            false => Self::Deny,
        }
    }
}
132
/// Callback invoked before tool execution for human-in-the-loop approval.
///
/// Receives the list of tool calls the LLM wants to execute.
/// Returns an [`ApprovalDecision`] indicating whether to proceed.
/// `AlwaysAllow`/`AlwaysDeny` additionally persist the decision as a
/// learned permission rule.
///
/// `Send + Sync` so the callback can be invoked from any async task.
pub type OnApproval = dyn Fn(&[crate::llm::types::ToolCall]) -> ApprovalDecision + Send + Sync;
140
/// Trait for LLM providers.
///
/// Uses RPITIT (`impl Future`) which means this trait is NOT dyn-compatible.
/// All consumers are generic over `P: LlmProvider`. This is intentional:
/// one provider per process, no need for trait objects.
///
/// For dynamic dispatch, use [`BoxedProvider`] which wraps any `LlmProvider`
/// behind [`DynLlmProvider`].
pub trait LlmProvider: Send + Sync {
    /// Send a completion request and wait for the full response.
    fn complete(
        &self,
        request: CompletionRequest,
    ) -> impl Future<Output = Result<CompletionResponse, Error>> + Send;

    /// Stream a completion, calling `on_text` for each text delta as it arrives.
    ///
    /// The returned `CompletionResponse` contains the full accumulated response
    /// (same as `complete()`), but text was emitted incrementally via the callback.
    ///
    /// Default: falls back to `complete()` (no incremental streaming).
    fn stream_complete(
        &self,
        request: CompletionRequest,
        on_text: &OnText,
    ) -> impl Future<Output = Result<CompletionResponse, Error>> + Send {
        // Fallback path never emits deltas; discard the callback explicitly.
        let _ = on_text;
        self.complete(request)
    }

    /// Return the model identifier, if known.
    ///
    /// Used for audit trail events. Default returns `None`.
    fn model_name(&self) -> Option<&str> {
        None
    }
}
178
179// ---------------------------------------------------------------------------
180// DynLlmProvider — object-safe adapter for LlmProvider (RPITIT → dyn)
181// ---------------------------------------------------------------------------
182
/// Object-safe version of [`LlmProvider`] for dynamic dispatch.
///
/// `LlmProvider` uses RPITIT (not dyn-compatible). This trait wraps it via
/// `Pin<Box<dyn Future>>` so providers can be stored as `Arc<dyn DynLlmProvider>`.
///
/// A blanket impl covers all `LlmProvider` types automatically.
///
/// Used by the Restate service layer (`AgentServiceImpl`) and by
/// [`BoxedProvider`] for type-erased standalone use.
pub trait DynLlmProvider: Send + Sync {
    /// Boxed-future version of [`LlmProvider::complete`] for object-safe dispatch.
    ///
    /// The `'a` lifetime ties the returned future to the `&self` borrow.
    fn complete<'a>(
        &'a self,
        request: CompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, Error>> + Send + 'a>>;

    /// Boxed-future version of [`LlmProvider::stream_complete`] for object-safe dispatch.
    ///
    /// Both `&self` and `on_text` must outlive the returned future (`'a`).
    fn stream_complete<'a>(
        &'a self,
        request: CompletionRequest,
        on_text: &'a OnText,
    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, Error>> + Send + 'a>>;

    /// Return the model identifier, if known.
    fn model_name(&self) -> Option<&str>;
}
209
// Blanket impl: every `LlmProvider` is automatically a `DynLlmProvider`.
impl<P: LlmProvider> DynLlmProvider for P {
    fn complete<'a>(
        &'a self,
        request: CompletionRequest,
    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, Error>> + Send + 'a>> {
        // Fully-qualified call: `P` exposes `complete` through both traits
        // here, so plain method syntax would be ambiguous.
        Box::pin(LlmProvider::complete(self, request))
    }

    fn stream_complete<'a>(
        &'a self,
        request: CompletionRequest,
        on_text: &'a OnText,
    ) -> Pin<Box<dyn Future<Output = Result<CompletionResponse, Error>> + Send + 'a>> {
        Box::pin(LlmProvider::stream_complete(self, request, on_text))
    }

    fn model_name(&self) -> Option<&str> {
        LlmProvider::model_name(self)
    }
}
230
231// ---------------------------------------------------------------------------
232// BoxedProvider — type-erased LlmProvider via DynLlmProvider
233// ---------------------------------------------------------------------------
234
/// Type-erased LLM provider for use when dynamic dispatch is needed.
///
/// Wraps any [`LlmProvider`] behind `Box<dyn DynLlmProvider>`. Implements
/// `LlmProvider` itself, so it can be used with `AgentRunner<BoxedProvider>`
/// and `Orchestrator<BoxedProvider>`, eliminating the need for generic code
/// at the call site.
///
/// The inner box is private; construct via [`BoxedProvider::new`] or
/// [`BoxedProvider::from_arc`].
///
/// # Example
///
/// ```ignore
/// let provider = BoxedProvider::new(AnthropicProvider::new(key, model));
/// let runner = AgentRunner::builder(Arc::new(provider))
///     .name("agent")
///     .build()?;
/// ```
pub struct BoxedProvider(Box<dyn DynLlmProvider>);
251
252impl BoxedProvider {
253    /// Create a type-erased provider from any concrete `LlmProvider`.
254    pub fn new<P: LlmProvider + 'static>(provider: P) -> Self {
255        Self(Box::new(provider))
256    }
257
258    /// Create a type-erased provider from an `Arc<P>`.
259    ///
260    /// Useful when the provider is already behind an `Arc` (e.g., shared between
261    /// the orchestrator and sub-agents) and needs to be converted to `BoxedProvider`
262    /// for type erasure without consuming the original.
263    pub fn from_arc<P: LlmProvider + 'static>(provider: Arc<P>) -> Self {
264        /// Internal adapter: delegates to the `Arc<P>` inner provider.
265        struct ArcAdapter<P>(Arc<P>);
266
267        impl<P: LlmProvider> LlmProvider for ArcAdapter<P> {
268            async fn complete(
269                &self,
270                request: CompletionRequest,
271            ) -> Result<CompletionResponse, Error> {
272                self.0.complete(request).await
273            }
274
275            async fn stream_complete(
276                &self,
277                request: CompletionRequest,
278                on_text: &OnText,
279            ) -> Result<CompletionResponse, Error> {
280                self.0.stream_complete(request, on_text).await
281            }
282
283            fn model_name(&self) -> Option<&str> {
284                self.0.model_name()
285            }
286        }
287
288        Self(Box::new(ArcAdapter(provider)))
289    }
290}
291
292impl LlmProvider for BoxedProvider {
293    async fn complete(&self, request: CompletionRequest) -> Result<CompletionResponse, Error> {
294        self.0.complete(request).await
295    }
296
297    async fn stream_complete(
298        &self,
299        request: CompletionRequest,
300        on_text: &OnText,
301    ) -> Result<CompletionResponse, Error> {
302        self.0.stream_complete(request, on_text).await
303    }
304
305    fn model_name(&self) -> Option<&str> {
306        self.0.model_name()
307    }
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use crate::llm::types::{ContentBlock, Message, StopReason, TokenUsage};
    use std::sync::{Arc, Mutex};

    // --- Test doubles ---

    /// Provider implementing only `complete`; `stream_complete` exercises
    /// the trait's default fallback.
    struct FakeProvider;

    impl LlmProvider for FakeProvider {
        async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
            Ok(CompletionResponse {
                content: vec![ContentBlock::Text {
                    text: "fake".into(),
                }],
                stop_reason: StopReason::EndTurn,
                usage: TokenUsage::default(),
                model: None,
            })
        }
    }

    /// Provider that overrides `stream_complete` and panics if the
    /// non-streaming path is taken, proving delegation routes correctly.
    struct StreamingFakeProvider;

    impl LlmProvider for StreamingFakeProvider {
        async fn complete(&self, _request: CompletionRequest) -> Result<CompletionResponse, Error> {
            panic!("should call stream_complete, not complete");
        }

        async fn stream_complete(
            &self,
            _request: CompletionRequest,
            on_text: &OnText,
        ) -> Result<CompletionResponse, Error> {
            // Emit two deltas so tests can verify incremental delivery order.
            on_text("hello");
            on_text(" world");
            Ok(CompletionResponse {
                content: vec![ContentBlock::Text {
                    text: "hello world".into(),
                }],
                stop_reason: StopReason::EndTurn,
                usage: TokenUsage::default(),
                model: None,
            })
        }
    }

    /// Minimal valid request used by every test below.
    fn test_request() -> CompletionRequest {
        CompletionRequest {
            system: String::new(),
            messages: vec![Message::user("test")],
            tools: vec![],
            max_tokens: 100,
            tool_choice: None,
            reasoning_effort: None,
        }
    }

    // --- DynLlmProvider / BoxedProvider ---

    #[test]
    fn dyn_llm_provider_wraps_provider() {
        // Compile-time check: the blanket impl makes FakeProvider coercible
        // to `&dyn DynLlmProvider`.
        let provider = FakeProvider;
        let dyn_provider: &dyn DynLlmProvider = &provider;
        let _ = dyn_provider;
    }

    #[tokio::test]
    async fn boxed_provider_delegates_complete() {
        let provider = BoxedProvider::new(FakeProvider);
        // Disambiguate: BoxedProvider implements both LlmProvider and DynLlmProvider
        let response = LlmProvider::complete(&provider, test_request())
            .await
            .unwrap();
        assert_eq!(response.text(), "fake");
    }

    #[tokio::test]
    async fn boxed_provider_delegates_stream_complete() {
        let provider = BoxedProvider::new(StreamingFakeProvider);
        let received = Arc::new(Mutex::new(Vec::<String>::new()));
        let received_clone = received.clone();
        let on_text: &OnText = &move |text: &str| {
            received_clone
                .lock()
                .expect("test lock")
                .push(text.to_string());
        };

        let response = LlmProvider::stream_complete(&provider, test_request(), on_text)
            .await
            .unwrap();
        assert_eq!(response.text(), "hello world");

        // Deltas must arrive incrementally and in order.
        let texts = received.lock().expect("test lock");
        assert_eq!(*texts, vec!["hello", " world"]);
    }

    #[test]
    fn boxed_provider_is_send_sync() {
        fn assert_send_sync<T: Send + Sync>() {}
        assert_send_sync::<BoxedProvider>();
    }

    #[tokio::test]
    async fn boxed_provider_default_stream_falls_back_to_complete() {
        // FakeProvider only implements complete; stream_complete should fall back
        let provider = BoxedProvider::new(FakeProvider);
        let on_text: &OnText = &|_| {};
        let response = LlmProvider::stream_complete(&provider, test_request(), on_text)
            .await
            .unwrap();
        assert_eq!(response.text(), "fake");
    }

    #[tokio::test]
    async fn boxed_provider_from_arc_delegates_complete() {
        let provider = Arc::new(FakeProvider);
        let boxed = BoxedProvider::from_arc(provider);
        let response = LlmProvider::complete(&boxed, test_request()).await.unwrap();
        assert_eq!(response.text(), "fake");
    }

    #[tokio::test]
    async fn boxed_provider_from_arc_delegates_stream_complete() {
        let provider = Arc::new(StreamingFakeProvider);
        let boxed = BoxedProvider::from_arc(provider);
        let received = Arc::new(Mutex::new(Vec::<String>::new()));
        let received_clone = received.clone();
        let on_text: &OnText = &move |text: &str| {
            received_clone
                .lock()
                .expect("test lock")
                .push(text.to_string());
        };
        let response = LlmProvider::stream_complete(&boxed, test_request(), on_text)
            .await
            .unwrap();
        assert_eq!(response.text(), "hello world");
        let texts = received.lock().expect("test lock");
        assert_eq!(*texts, vec!["hello", " world"]);
    }

    #[test]
    fn model_name_default_is_none() {
        let provider = FakeProvider;
        assert!(LlmProvider::model_name(&provider).is_none());
    }

    #[test]
    fn boxed_provider_preserves_model_name() {
        struct NamedProvider;
        impl LlmProvider for NamedProvider {
            async fn complete(
                &self,
                _request: CompletionRequest,
            ) -> Result<CompletionResponse, Error> {
                unimplemented!()
            }
            fn model_name(&self) -> Option<&str> {
                Some("test-model")
            }
        }
        let boxed = BoxedProvider::new(NamedProvider);
        assert_eq!(LlmProvider::model_name(&boxed), Some("test-model"));
    }

    #[test]
    fn boxed_provider_from_arc_preserves_model_name() {
        struct NamedProvider;
        impl LlmProvider for NamedProvider {
            async fn complete(
                &self,
                _request: CompletionRequest,
            ) -> Result<CompletionResponse, Error> {
                unimplemented!()
            }
            fn model_name(&self) -> Option<&str> {
                Some("arc-model")
            }
        }
        let boxed = BoxedProvider::from_arc(Arc::new(NamedProvider));
        assert_eq!(LlmProvider::model_name(&boxed), Some("arc-model"));
    }

    #[tokio::test]
    async fn boxed_provider_from_arc_shares_underlying_provider() {
        // Verify from_arc shares the underlying provider via Arc (not a copy)
        let call_count = Arc::new(std::sync::atomic::AtomicUsize::new(0));
        struct CountingProvider(Arc<std::sync::atomic::AtomicUsize>);
        impl LlmProvider for CountingProvider {
            async fn complete(
                &self,
                _request: CompletionRequest,
            ) -> Result<CompletionResponse, crate::error::Error> {
                self.0.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
                Ok(CompletionResponse {
                    content: vec![ContentBlock::Text {
                        text: "counted".into(),
                    }],
                    stop_reason: StopReason::EndTurn,
                    usage: TokenUsage::default(),
                    model: None,
                })
            }
        }

        let inner = Arc::new(CountingProvider(call_count.clone()));
        let boxed1 = BoxedProvider::from_arc(inner.clone());
        let boxed2 = BoxedProvider::from_arc(inner);

        LlmProvider::complete(&boxed1, test_request())
            .await
            .unwrap();
        LlmProvider::complete(&boxed2, test_request())
            .await
            .unwrap();

        assert_eq!(
            call_count.load(std::sync::atomic::Ordering::Relaxed),
            2,
            "both boxed providers should share the same underlying provider"
        );
    }

    // --- ApprovalDecision ---

    #[test]
    fn approval_decision_from_true() {
        let decision = ApprovalDecision::from(true);
        assert_eq!(decision, ApprovalDecision::Allow);
        assert!(decision.is_allowed());
        assert!(!decision.is_persistent());
    }

    #[test]
    fn approval_decision_from_false() {
        let decision = ApprovalDecision::from(false);
        assert_eq!(decision, ApprovalDecision::Deny);
        assert!(!decision.is_allowed());
        assert!(!decision.is_persistent());
    }

    #[test]
    fn approval_decision_always_allow() {
        let decision = ApprovalDecision::AlwaysAllow;
        assert!(decision.is_allowed());
        assert!(decision.is_persistent());
    }

    #[test]
    fn approval_decision_always_deny() {
        let decision = ApprovalDecision::AlwaysDeny;
        assert!(!decision.is_allowed());
        assert!(decision.is_persistent());
    }
}