// atomr_infer_core/runner.rs

//! `ModelRunner` — the trait every runtime backend implements.
//!
//! This is the seam that makes the actor decomposition work for both
//! local-GPU and remote-network runtimes. Doc §5.4. The trait is
//! deliberately small; backend-specific scheduling lives inside the
//! runner's `execute` body.

8use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::stream::BoxStream;
12
13use crate::batch::ExecuteBatch;
14use crate::deployment::RateLimits;
15use crate::error::{InferenceError, InferenceResult};
16use crate::runtime::{RuntimeKind, TransportKind};
17use crate::tokens::TokenChunk;
18
/// The result of `ModelRunner::execute`: a type-erased stream of token
/// chunks. Local runtimes typically return a stream even for unary
/// calls (one final chunk); remote runtimes return a multi-chunk
/// stream for SSE responses and a single-chunk stream otherwise.
/// Callers always treat it as a stream.
pub struct RunHandle {
    // Boxed and 'static so the handle can outlive the `execute` call
    // and be polled from any task; each item is a chunk or an error.
    inner: BoxStream<'static, InferenceResult<TokenChunk>>,
}
26
27impl RunHandle {
28    pub fn streaming(inner: BoxStream<'static, InferenceResult<TokenChunk>>) -> Self {
29        Self { inner }
30    }
31
32    pub fn into_stream(self) -> BoxStream<'static, InferenceResult<TokenChunk>> {
33        self.inner
34    }
35}
36
37impl std::fmt::Debug for RunHandle {
38    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39        f.debug_struct("RunHandle").finish_non_exhaustive()
40    }
41}
42
/// Where to load weights from. Local runtimes implement; remote
/// runtimes no-op.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum WeightSource {
    /// Fetch from a Hugging Face repo, optionally pinned to a
    /// revision. NOTE(review): `None` presumably means the repo's
    /// default branch — confirm against the loader that consumes this.
    HuggingFace {
        repo: String,
        revision: Option<String>,
    },
    /// Load weights from a path on the local filesystem.
    LocalPath {
        path: std::path::PathBuf,
    },
    /// The runtime knows how to fetch its own weights (vLLM, mistralrs).
    RuntimeManaged,
}
58
/// Why a session rebuild was requested. Drives the runtime-specific
/// rebuild behaviour described in §3.4.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SessionRebuildCause {
    /// The CUDA context is unusable (local runtimes).
    CudaContextPoisoned,
    /// The remote endpoint rejected our credentials (remote runtimes).
    RemoteAuthFailure,
    /// The remote endpoint's configuration changed (remote runtimes).
    RemoteConfigChange,
    /// A rebuild was explicitly requested rather than triggered by a
    /// failure.
    Manual,
}
69
/// Opaque CUDA-context handle, kept type-erased so `inference-core`
/// doesn't depend on `atomr-accel`/`cudarc` — preserving the §10.4
/// dependency budget so `inference --features remote-only` builds
/// compile no GPU deps at all.
///
/// Real local runtimes downcast this to
/// `Arc<atomr_accel_cuda::device::DeviceState>` (which itself wraps
/// the `cudarc::driver::CudaContext`) at the seam; tests and remote
/// runtimes pass `None` wherever an `Option<&CudaContextHandle>` is
/// accepted.
pub type CudaContextHandle = Arc<dyn std::any::Any + Send + Sync>;
78
#[async_trait]
pub trait ModelRunner: Send + Sync {
    /// Run an inference. For local runtimes, dispatches kernels; for
    /// remote runtimes, sends an HTTP request. Returns immediately;
    /// completion is observed via the returned `RunHandle` stream.
    async fn execute(&mut self, batch: ExecuteBatch) -> InferenceResult<RunHandle>;

    /// Local runtimes load weights to GPU; remote runtimes default to
    /// a no-op.
    ///
    /// `_ctx` is the type-erased CUDA context (see `CudaContextHandle`);
    /// remote runtimes and tests pass `None`. The default implementation
    /// succeeds without doing anything.
    async fn load_weights(
        &mut self,
        _ctx: Option<&CudaContextHandle>,
        _source: WeightSource,
    ) -> InferenceResult<()> {
        Ok(())
    }

    /// Local runtimes rebuild after CUDA context poison; remote
    /// runtimes rebuild after auth failure or config change.
    async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()>;

    /// Which backend this runner drives (e.g. `Vllm`, `Python(_)`; see
    /// `RuntimeKind` for the full set). Also feeds `gil_pinned`.
    fn runtime_kind(&self) -> RuntimeKind;
    /// How this runner reaches its backend; variants live in
    /// `TransportKind`.
    fn transport_kind(&self) -> TransportKind;
    /// True when the backend is subject to the Python GIL — derived
    /// purely from the runtime kind (`Vllm` or any `Python(_)`).
    fn gil_pinned(&self) -> bool {
        matches!(self.runtime_kind(), RuntimeKind::Vllm | RuntimeKind::Python(_))
    }

    /// Rate-limit metadata. Returns `None` for local runtimes; remote
    /// runtimes return their configured limits so the
    /// `RateLimiterActor` can be initialized at deploy time.
    fn rate_limits(&self) -> Option<&RateLimits> {
        None
    }

    /// Best-effort cost estimate for the given batch (USD). Used by
    /// `TieredRouter`-style actors and budget enforcement. Local
    /// runtimes default to 0 (compute cost is amortized).
    fn estimate_cost_usd(&self, _batch: &ExecuteBatch) -> f64 {
        0.0
    }
}
120
121/// Helper: convert a generic error string to an `InferenceError`. Useful
122/// inside `RunHandle` stream futures that need to lift unrelated errors.
123pub fn lift_internal<E: std::fmt::Display>(err: E) -> InferenceError {
124    InferenceError::Internal(err.to_string())
125}