// atomr_infer_core/runner.rs
//! `ModelRunner` — the trait every runtime backend implements.
//!
//! This is the seam that makes the actor decomposition work for both
//! local-GPU and remote-network runtimes. Doc §5.4. The trait is
//! deliberately small; backend-specific scheduling lives inside the
//! runner's `execute` body.

8use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::stream::BoxStream;
12
13use crate::batch::ExecuteBatch;
14use crate::deployment::RateLimits;
15use crate::error::{InferenceError, InferenceResult};
16use crate::runtime::{RuntimeKind, TransportKind};
17use crate::tokens::TokenChunk;
18
19/// The result of `ModelRunner::execute`. Local runtimes typically
20/// return `Streaming` even for unary calls (one final chunk); remote
21/// runtimes return `Streaming` for SSE responses and a single-chunk
22/// stream otherwise. Callers always treat it as a stream.
23pub struct RunHandle {
24 inner: BoxStream<'static, InferenceResult<TokenChunk>>,
25}
26
27impl RunHandle {
28 pub fn streaming(inner: BoxStream<'static, InferenceResult<TokenChunk>>) -> Self {
29 Self { inner }
30 }
31
32 pub fn into_stream(self) -> BoxStream<'static, InferenceResult<TokenChunk>> {
33 self.inner
34 }
35}
36
37impl std::fmt::Debug for RunHandle {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.debug_struct("RunHandle").finish_non_exhaustive()
40 }
41}
42
/// Where to load weights from. Local runtimes implement; remote
/// runtimes no-op.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub enum WeightSource {
    /// Fetch weights from a Hugging Face Hub repository.
    HuggingFace {
        /// Repository identifier (e.g. `org/model`).
        repo: String,
        /// Optional revision (branch, tag, or commit). `None`
        /// presumably selects the repo default — TODO confirm with the
        /// local-runtime weight loader.
        revision: Option<String>,
    },
    /// Load weights from a path on the local filesystem.
    LocalPath {
        /// Filesystem location of the weight files.
        path: std::path::PathBuf,
    },
    /// The runtime knows how to fetch its own weights (vLLM, mistralrs).
    RuntimeManaged,
}
58
/// Why a session rebuild was requested. Drives the runtime-specific
/// rebuild behaviour described in §3.4.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SessionRebuildCause {
    /// A local runtime's CUDA context was poisoned and must be
    /// recreated (see `ModelRunner::rebuild_session`).
    CudaContextPoisoned,
    /// A remote runtime's credentials were rejected by the provider.
    RemoteAuthFailure,
    /// A remote runtime's configuration changed and the session must
    /// be re-established against the new config.
    RemoteConfigChange,
    /// The rebuild was requested explicitly rather than triggered by a
    /// failure.
    Manual,
}
69
/// Opaque CUDA-context handle. Real local runtimes downcast to
/// `Arc<atomr_accel_cuda::device::DeviceState>` (which itself wraps the
/// `cudarc::driver::CudaContext`); tests and remote runtimes pass
/// `None`. Kept type-erased so `inference-core` doesn't depend on
/// `atomr-accel`/`cudarc` — preserves the §10.4 dependency budget so
/// `inference --features remote-only` builds compile no GPU deps at
/// all. Local-runtime crates downcast at the seam.
///
/// NOTE(review): a failed downcast at the seam would indicate a wiring
/// bug between the runtime crate and its accel backend — presumably
/// treated as fatal there; confirm in the local-runtime crates.
pub type CudaContextHandle = Arc<dyn std::any::Any + Send + Sync>;
78
#[async_trait]
pub trait ModelRunner: Send + Sync {
    /// Run an inference. For local runtimes, dispatches kernels; for
    /// remote runtimes, sends an HTTP request. Returns immediately;
    /// completion is observed via the returned `RunHandle` stream.
    ///
    /// Errors that occur after this call returns arrive as `Err` items
    /// on the handle's stream (its item type is
    /// `InferenceResult<TokenChunk>`), not from `execute` itself.
    async fn execute(&mut self, batch: ExecuteBatch) -> InferenceResult<RunHandle>;

    /// Local runtimes load weights to GPU; remote runtimes default to
    /// a no-op.
    ///
    /// `_ctx` is the type-erased CUDA context (see
    /// [`CudaContextHandle`]); tests and remote runtimes pass `None`.
    async fn load_weights(
        &mut self,
        _ctx: Option<&CudaContextHandle>,
        _source: WeightSource,
    ) -> InferenceResult<()> {
        // Default: succeed without doing anything, so remote runtimes
        // need not override this.
        Ok(())
    }

    /// Local runtimes rebuild after CUDA context poison; remote
    /// runtimes rebuild after auth failure or config change.
    async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()>;

    /// Which runtime backs this runner (e.g. `Vllm`, `Python(_)`).
    fn runtime_kind(&self) -> RuntimeKind;

    /// The transport this runner uses to reach its runtime.
    fn transport_kind(&self) -> TransportKind;

    /// `true` for the `Vllm` and `Python(_)` runtime kinds —
    /// presumably because those backends execute behind the Python
    /// GIL; confirm against the runtime crates.
    fn gil_pinned(&self) -> bool {
        matches!(self.runtime_kind(), RuntimeKind::Vllm | RuntimeKind::Python(_))
    }

    /// Rate-limit metadata. Returns `None` for local runtimes; remote
    /// runtimes return their configured limits so the
    /// `RateLimiterActor` can be initialized at deploy time.
    fn rate_limits(&self) -> Option<&RateLimits> {
        None
    }

    /// Best-effort cost estimate for the given batch (USD). Used by
    /// `TieredRouter`-style actors and budget enforcement. Local
    /// runtimes default to 0 (compute cost is amortized).
    fn estimate_cost_usd(&self, _batch: &ExecuteBatch) -> f64 {
        0.0
    }
}
120
121/// Helper: convert a generic error string to an `InferenceError`. Useful
122/// inside `RunHandle` stream futures that need to lift unrelated errors.
123pub fn lift_internal<E: std::fmt::Display>(err: E) -> InferenceError {
124 InferenceError::Internal(err.to_string())
125}