// inference_core/runner.rs
//! `ModelRunner` — the trait every runtime backend implements.
//!
//! This is the seam that makes the actor decomposition work for both
//! local-GPU and remote-network runtimes. Doc §5.4. The trait is
//! deliberately small; backend-specific scheduling lives inside the
//! runner's `execute` body.
8use std::sync::Arc;
9
10use async_trait::async_trait;
11use futures::stream::BoxStream;
12
13use crate::batch::ExecuteBatch;
14use crate::deployment::RateLimits;
15use crate::error::{InferenceError, InferenceResult};
16use crate::runtime::{RuntimeKind, TransportKind};
17use crate::tokens::TokenChunk;
18
/// The result of [`ModelRunner::execute`]: a handle over a type-erased
/// stream of token chunks. Local runtimes typically produce a stream
/// even for unary calls (one final chunk); remote runtimes produce a
/// multi-chunk stream for SSE responses and a single-chunk stream
/// otherwise. Callers always treat it as a stream — build one with
/// [`RunHandle::streaming`], consume it with [`RunHandle::into_stream`].
pub struct RunHandle {
    // Boxed/type-erased so a single handle type serves both local and
    // remote backends without generics leaking into the trait.
    inner: BoxStream<'static, InferenceResult<TokenChunk>>,
}
26
27impl RunHandle {
28 pub fn streaming(inner: BoxStream<'static, InferenceResult<TokenChunk>>) -> Self {
29 Self { inner }
30 }
31
32 pub fn into_stream(self) -> BoxStream<'static, InferenceResult<TokenChunk>> {
33 self.inner
34 }
35}
36
37impl std::fmt::Debug for RunHandle {
38 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
39 f.debug_struct("RunHandle").finish_non_exhaustive()
40 }
41}
42
/// Where to load weights from. Local runtimes implement; remote
/// runtimes no-op.
#[derive(Debug, Clone)]
pub enum WeightSource {
    /// Fetch weights from a Hugging Face Hub repository.
    HuggingFace {
        // Repository identifier (presumably `org/model` form — confirm
        // against the local-runtime loader).
        repo: String,
        // Specific revision (branch/tag/commit); `None` presumably
        // selects the repo default — TODO confirm.
        revision: Option<String>,
    },
    /// Load weights from a path on the local filesystem.
    LocalPath {
        path: std::path::PathBuf,
    },
    /// The runtime knows how to fetch its own weights (vLLM, mistralrs).
    RuntimeManaged,
}
57
/// Why a session rebuild was requested. Drives the runtime-specific
/// rebuild behaviour described in §3.4.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SessionRebuildCause {
    /// A local runtime's CUDA context was poisoned and must be rebuilt.
    CudaContextPoisoned,
    /// A remote runtime's request was rejected for authentication reasons.
    RemoteAuthFailure,
    /// A remote runtime's configuration changed.
    RemoteConfigChange,
    /// Explicitly requested by an operator or caller.
    Manual,
}
67
/// Opaque CUDA-context handle. Real local runtimes downcast to
/// `Arc<rakka_accel::cuda::device::DeviceState>` (which itself wraps the
/// `cudarc::driver::CudaContext`); tests and remote runtimes pass
/// `None`. Kept type-erased so `inference-core` doesn't depend on
/// `rakka-accel`/`cudarc` — preserving the §10.4 dependency budget so
/// that `inference --features remote-only` builds pull in no GPU deps
/// at all. Local-runtime crates downcast at the seam.
pub type CudaContextHandle = Arc<dyn std::any::Any + Send + Sync>;
76
77#[async_trait]
78pub trait ModelRunner: Send + Sync {
79 /// Run an inference. For local runtimes, dispatches kernels; for
80 /// remote runtimes, sends an HTTP request. Returns immediately;
81 /// completion is observed via the returned `RunHandle` stream.
82 async fn execute(&mut self, batch: ExecuteBatch) -> InferenceResult<RunHandle>;
83
84 /// Local runtimes load weights to GPU; remote runtimes default to
85 /// a no-op.
86 async fn load_weights(
87 &mut self,
88 _ctx: Option<&CudaContextHandle>,
89 _source: WeightSource,
90 ) -> InferenceResult<()> {
91 Ok(())
92 }
93
94 /// Local runtimes rebuild after CUDA context poison; remote
95 /// runtimes rebuild after auth failure or config change.
96 async fn rebuild_session(&mut self, cause: SessionRebuildCause) -> InferenceResult<()>;
97
98 fn runtime_kind(&self) -> RuntimeKind;
99 fn transport_kind(&self) -> TransportKind;
100 fn gil_pinned(&self) -> bool {
101 matches!(self.runtime_kind(), RuntimeKind::Vllm | RuntimeKind::Python(_))
102 }
103
104 /// Rate-limit metadata. Returns `None` for local runtimes; remote
105 /// runtimes return their configured limits so the
106 /// `RateLimiterActor` can be initialized at deploy time.
107 fn rate_limits(&self) -> Option<&RateLimits> {
108 None
109 }
110
111 /// Best-effort cost estimate for the given batch (USD). Used by
112 /// `TieredRouter`-style actors and budget enforcement. Local
113 /// runtimes default to 0 (compute cost is amortized).
114 fn estimate_cost_usd(&self, _batch: &ExecuteBatch) -> f64 {
115 0.0
116 }
117}
118
119/// Helper: convert a generic error string to an `InferenceError`. Useful
120/// inside `RunHandle` stream futures that need to lift unrelated errors.
121pub fn lift_internal<E: std::fmt::Display>(err: E) -> InferenceError {
122 InferenceError::Internal(err.to_string())
123}