atomr_infer_runtime_litellm/
lib.rs1#![forbid(unsafe_code)]
16#![deny(rust_2018_idioms)]
17
18use std::sync::Arc;
19
20use arc_swap::ArcSwap;
21use async_trait::async_trait;
22use serde::{Deserialize, Serialize};
23use url::Url;
24
25use atomr_infer_core::batch::ExecuteBatch;
26use atomr_infer_core::deployment::{RateLimits, RetryPolicy, Timeouts};
27use atomr_infer_core::error::InferenceResult;
28use atomr_infer_core::runner::{ModelRunner, RunHandle, SessionRebuildCause};
29use atomr_infer_core::runtime::{CircuitBreakerConfig, ProviderKind, RuntimeKind, TransportKind};
30
31use atomr_infer_remote_core::session::SessionSnapshot;
32use atomr_infer_runtime_openai::{OpenAiConfig, OpenAiRunner, OpenAiVariant};
33
/// Serializable configuration for a LiteLLM-backed runner.
///
/// Only `endpoint` and `api_key` are required when deserializing; every other
/// field falls back to a default. Convert to an [`OpenAiConfig`] with
/// [`LiteLlmConfig::into_openai`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LiteLlmConfig {
    /// Base URL of the LiteLLM endpoint; forwarded as `OpenAiVariant::Direct { endpoint }`.
    pub endpoint: Url,
    /// Where the API key is sourced from. Note: `into_openai` does not read this
    /// field — the caller passes the OpenAI-runtime secret separately.
    pub api_key: SecretRef,
    /// Rate limits; `RateLimits::default()` when omitted.
    #[serde(default)]
    pub rate_limits: RateLimits,
    /// Retry policy; when omitted, `default_retry()` (max_retries = 1) is used.
    #[serde(default = "default_retry")]
    pub retry: RetryPolicy,
    /// Circuit-breaker settings; `CircuitBreakerConfig::default()` when omitted.
    #[serde(default)]
    pub circuit_breaker: CircuitBreakerConfig,
    /// Request timeouts; `Timeouts::default()` when omitted.
    #[serde(default)]
    pub timeouts: Timeouts,
}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49#[serde(tag = "from", rename_all = "snake_case")]
50pub enum SecretRef {
51 Env { name: String },
52 File { path: std::path::PathBuf },
53 Inline { value: String },
54}
55
56fn default_retry() -> RetryPolicy {
57 RetryPolicy {
60 max_retries: 1,
61 ..RetryPolicy::default()
62 }
63}
64
65impl LiteLlmConfig {
66 pub fn into_openai(self, openai_secret: atomr_infer_runtime_openai::config::SecretRef) -> OpenAiConfig {
67 OpenAiConfig {
68 variant: OpenAiVariant::Direct {
69 endpoint: self.endpoint,
70 },
71 api_key: openai_secret,
72 organization: None,
73 project: None,
74 rate_limits: self.rate_limits,
75 retry: self.retry,
76 circuit_breaker: self.circuit_breaker,
77 timeouts: self.timeouts,
78 }
79 }
80}
81
/// Model runner for LiteLLM endpoints.
///
/// All request handling is delegated to an [`OpenAiRunner`]; this wrapper only
/// changes the runtime/transport identity it reports.
pub struct LiteLlmRunner {
    /// The OpenAI-compatible runner that performs the actual work.
    inner: OpenAiRunner,
}
88
89impl LiteLlmRunner {
90 pub fn new(config: OpenAiConfig, session: Arc<ArcSwap<SessionSnapshot>>) -> InferenceResult<Self> {
91 Ok(Self {
92 inner: OpenAiRunner::new(config, session)?,
93 })
94 }
95}
96
97#[async_trait]
98impl ModelRunner for LiteLlmRunner {
99 #[tracing::instrument(skip(self, batch), fields(request_id = %batch.request_id, model = %batch.model))]
100 async fn execute(&mut self, batch: ExecuteBatch) -> InferenceResult<RunHandle> {
101 self.inner.execute(batch).await
102 }
103
104 async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()> {
105 self.inner.rebuild_session(cause).await
106 }
107
108 fn runtime_kind(&self) -> RuntimeKind {
109 RuntimeKind::LiteLlm
110 }
111 fn transport_kind(&self) -> TransportKind {
112 TransportKind::RemoteNetwork {
113 provider: ProviderKind::LiteLlm,
114 }
115 }
116 fn rate_limits(&self) -> Option<&RateLimits> {
117 self.inner.rate_limits()
118 }
119 fn estimate_cost_usd(&self, batch: &ExecuteBatch) -> f64 {
120 self.inner.estimate_cost_usd(batch)
124 }
125}