// inference_core/runner.rs
//! `ModelRunner` — the trait every runtime backend implements.
//!
//! This is the seam that makes the actor decomposition work for both
//! local-GPU and remote-network runtimes. Doc §5.4. The trait is
//! deliberately small; backend-specific scheduling lives inside the
//! runner's `execute` body.

use std::sync::Arc;

use async_trait::async_trait;
use futures::stream::BoxStream;

use crate::batch::ExecuteBatch;
use crate::deployment::RateLimits;
use crate::error::{InferenceError, InferenceResult};
use crate::runtime::{RuntimeKind, TransportKind};
use crate::tokens::TokenChunk;

/// The result of `ModelRunner::execute`. Local runtimes typically
/// return `Streaming` even for unary calls (one final chunk); remote
/// runtimes return `Streaming` for SSE responses and a single-chunk
/// stream otherwise. Callers always treat it as a stream.
pub struct RunHandle {
    // Type-erased, owned ('static) stream of token chunks; each item is
    // either a chunk or an `InferenceError` surfaced mid-stream.
    inner: BoxStream<'static, InferenceResult<TokenChunk>>,
}

27impl RunHandle {
28    pub fn streaming(inner: BoxStream<'static, InferenceResult<TokenChunk>>) -> Self {
29        Self { inner }
30    }
31
32    pub fn into_stream(self) -> BoxStream<'static, InferenceResult<TokenChunk>> {
33        self.inner
34    }
35}
36
// Hand-written because `BoxStream` has no `Debug` impl, so `derive` is
// unavailable; the stream contents are elided.
impl std::fmt::Debug for RunHandle {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Renders as `RunHandle { .. }`.
        f.debug_struct("RunHandle").finish_non_exhaustive()
    }
}

/// Where to load weights from. Local runtimes implement; remote
/// runtimes no-op.
#[derive(Debug, Clone)]
pub enum WeightSource {
    /// Fetch from a Hugging Face Hub repository. `revision` of `None`
    /// presumably selects the repo's default revision — confirm at the
    /// loading site.
    HuggingFace {
        repo: String,
        revision: Option<String>,
    },
    /// Load from a filesystem path already present on the host.
    LocalPath {
        path: std::path::PathBuf,
    },
    /// The runtime knows how to fetch its own weights (vLLM, mistralrs).
    RuntimeManaged,
}

/// Why a session rebuild was requested. Drives the runtime-specific
/// rebuild behaviour described in §3.4.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionRebuildCause {
    /// Local runtimes: the CUDA context is poisoned and must be rebuilt
    /// (see `ModelRunner::rebuild_session`).
    CudaContextPoisoned,
    /// Remote runtimes: authentication against the endpoint failed.
    RemoteAuthFailure,
    /// Remote runtimes: the endpoint's configuration changed.
    RemoteConfigChange,
    /// Explicitly requested rebuild, not driven by an observed failure.
    Manual,
}

/// Opaque CUDA-context handle. Real local runtimes downcast to
/// `Arc<rakka_accel::cuda::device::DeviceState>` (which itself wraps the
/// `cudarc::driver::CudaContext`); tests and remote runtimes pass
/// `None`. Kept type-erased so `inference-core` doesn't depend on
/// `rakka-accel`/`cudarc` — preserves the §10.4 dependency budget so
/// `inference --features remote-only` builds compile no GPU deps at
/// all. Local-runtime crates downcast at the seam.
///
/// The `Send + Sync` bounds let the handle be shared across threads.
pub type CudaContextHandle = Arc<dyn std::any::Any + Send + Sync>;

#[async_trait]
pub trait ModelRunner: Send + Sync {
    /// Run an inference. For local runtimes, dispatches kernels; for
    /// remote runtimes, sends an HTTP request. Returns immediately;
    /// completion is observed via the returned `RunHandle` stream.
    async fn execute(&mut self, batch: ExecuteBatch) -> InferenceResult<RunHandle>;

    /// Local runtimes load weights to GPU; remote runtimes default to
    /// a no-op.
    async fn load_weights(
        &mut self,
        _ctx: Option<&CudaContextHandle>,
        _source: WeightSource,
    ) -> InferenceResult<()> {
        Ok(())
    }

    /// Local runtimes rebuild after CUDA context poison; remote
    /// runtimes rebuild after auth failure or config change.
    async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()>;

    /// Which runtime implementation backs this runner.
    fn runtime_kind(&self) -> RuntimeKind;

    /// The transport this runner uses (see `TransportKind`).
    fn transport_kind(&self) -> TransportKind;

    /// True when the runtime kind is `Vllm` or any `Python(_)` —
    /// presumably the runtimes whose execution is serialized behind a
    /// Python GIL. Override if a runtime's behaviour diverges from its
    /// kind.
    fn gil_pinned(&self) -> bool {
        matches!(self.runtime_kind(), RuntimeKind::Vllm | RuntimeKind::Python(_))
    }

    /// Rate-limit metadata. Returns `None` for local runtimes; remote
    /// runtimes return their configured limits so the
    /// `RateLimiterActor` can be initialized at deploy time.
    fn rate_limits(&self) -> Option<&RateLimits> {
        None
    }

    /// Best-effort cost estimate for the given batch (USD). Used by
    /// `TieredRouter`-style actors and budget enforcement. Local
    /// runtimes default to 0 (compute cost is amortized).
    fn estimate_cost_usd(&self, _batch: &ExecuteBatch) -> f64 {
        0.0
    }
}

119/// Helper: convert a generic error string to an `InferenceError`. Useful
120/// inside `RunHandle` stream futures that need to lift unrelated errors.
121pub fn lift_internal<E: std::fmt::Display>(err: E) -> InferenceError {
122    InferenceError::Internal(err.to_string())
123}