Skip to main content

inference_core/
batch.rs

//! Request batch — what the runtime executes.
//!
//! `ExecuteBatch` is intentionally small: it carries one logical
//! request's worth of input. Local runtimes that batch internally (vLLM,
//! TensorRT) batch *across* `ExecuteBatch` instances inside their own
//! engine module — see doc §5.2 ("scheduler and batching are modules,
//! not actors").
8
9use serde::{Deserialize, Serialize};
10
11/// One conversation message. OpenAI-compatible shape so the gateway can
12/// pass it through with minimal translation; provider-specific
13/// runtimes (Anthropic, Gemini) translate at the edge.
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct Message {
16    pub role: Role,
17    pub content: MessageContent,
18}
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
21#[serde(rename_all = "lowercase")]
22pub enum Role {
23    System,
24    User,
25    Assistant,
26    Tool,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(untagged)]
31pub enum MessageContent {
32    Text(String),
33    Parts(Vec<ContentPart>),
34}
35
36#[derive(Debug, Clone, Serialize, Deserialize)]
37#[serde(tag = "type", rename_all = "snake_case")]
38pub enum ContentPart {
39    Text {
40        text: String,
41    },
42    /// Base64-encoded image input. Provider runtimes translate to their
43    /// preferred wire format.
44    ImageBase64 {
45        mime: String,
46        data: String,
47    },
48    /// URL-referenced image (provider-supported only).
49    ImageUrl {
50        url: String,
51    },
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize, Default)]
55pub struct SamplingParams {
56    pub temperature: Option<f32>,
57    pub top_p: Option<f32>,
58    pub top_k: Option<u32>,
59    pub max_tokens: Option<u32>,
60    pub stop: Vec<String>,
61    pub presence_penalty: Option<f32>,
62    pub frequency_penalty: Option<f32>,
63    pub seed: Option<u64>,
64}
65
66/// One unit of work handed to a `ModelRunner`. `request_id` is the
67/// `RequestActor`'s identifier so completions can be correlated back.
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct ExecuteBatch {
70    pub request_id: String,
71    pub model: String,
72    pub messages: Vec<Message>,
73    pub sampling: SamplingParams,
74    /// True if the caller wants token-by-token streaming (`Tokens`
75    /// chunks). False if a single final `Tokens` is acceptable.
76    pub stream: bool,
77    /// Best-effort estimate of input + max_output tokens, used by
78    /// `RateLimiterActor` to acquire a TPM permit before the request
79    /// hits the wire.
80    pub estimated_tokens: u32,
81}
82
83impl ExecuteBatch {
84    pub fn estimated_tokens(&self) -> u32 {
85        self.estimated_tokens.max(1)
86    }
87}