ferrum_types/requests.rs

//! Request and response types for inference

use crate::{ids::*, models::TokenUsage, FinishReason, Priority, SamplingParams, TokenId};
use chrono::{DateTime, Utc};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Inference request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceRequest {
    /// Unique request identifier
    pub id: RequestId,
    /// Input prompt text
    pub prompt: String,
    /// Model to use for inference
    pub model_id: ModelId,
    /// Sampling parameters
    pub sampling_params: SamplingParams,
    /// Whether to stream the response
    pub stream: bool,
    /// Request priority
    pub priority: Priority,
    /// Client identifier
    pub client_id: Option<ClientId>,
    /// Session identifier for stateful interactions
    pub session_id: Option<SessionId>,
    /// Request creation timestamp
    pub created_at: DateTime<Utc>,
    /// Additional metadata
    pub metadata: HashMap<String, serde_json::Value>,
}

impl InferenceRequest {
    /// Create a new inference request with default sampling parameters and priority
    pub fn new(prompt: impl Into<String>, model_id: impl Into<ModelId>) -> Self {
        Self {
            id: RequestId::new(),
            prompt: prompt.into(),
            model_id: model_id.into(),
            sampling_params: SamplingParams::default(),
            stream: false,
            priority: Priority::default(),
            client_id: None,
            session_id: None,
            created_at: Utc::now(),
            metadata: HashMap::new(),
        }
    }

    /// Set sampling parameters
    pub fn with_sampling_params(mut self, params: SamplingParams) -> Self {
        self.sampling_params = params;
        self
    }

    /// Enable or disable streaming
    pub fn with_stream(mut self, stream: bool) -> Self {
        self.stream = stream;
        self
    }

    /// Set priority
    pub fn with_priority(mut self, priority: Priority) -> Self {
        self.priority = priority;
        self
    }

    /// Set client ID
    pub fn with_client_id(mut self, client_id: impl Into<ClientId>) -> Self {
        self.client_id = Some(client_id.into());
        self
    }

    /// Set session ID
    pub fn with_session_id(mut self, session_id: SessionId) -> Self {
        self.session_id = Some(session_id);
        self
    }

    /// Add a metadata entry
    pub fn with_metadata(mut self, key: impl Into<String>, value: serde_json::Value) -> Self {
        self.metadata.insert(key.into(), value);
        self
    }
}
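
// Test-only usage sketch of the builder above. The `"demo-model"` and
// `"client-123"` literals assume `ModelId: From<&str>` and `ClientId: From<&str>`,
// which this file does not itself define.
#[cfg(test)]
mod inference_request_example {
    use super::*;

    #[test]
    fn builds_a_streaming_request() {
        let request = InferenceRequest::new("Summarize this text.", "demo-model")
            .with_stream(true)
            .with_client_id("client-123")
            .with_metadata("trace", serde_json::json!({"span": "demo"}));

        assert!(request.stream);
        assert_eq!(request.prompt, "Summarize this text.");
        assert!(request.session_id.is_none());
    }
}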

/// Inference response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InferenceResponse {
    /// Request ID this response corresponds to
    pub request_id: RequestId,
    /// Generated text
    pub text: String,
    /// Generated token IDs
    pub tokens: Vec<TokenId>,
    /// Reason for completion
    pub finish_reason: FinishReason,
    /// Token usage statistics
    pub usage: TokenUsage,
    /// Total latency in milliseconds
    pub latency_ms: u64,
    /// Response creation timestamp
    pub created_at: DateTime<Utc>,
    /// Additional response metadata
    pub metadata: HashMap<String, serde_json::Value>,
}
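
// Illustrative helper (not part of the original API): decode throughput derived
// only from fields shown on `InferenceResponse` above. Guards against division
// by zero when `latency_ms` is 0.
pub fn tokens_per_second(response: &InferenceResponse) -> f64 {
    if response.latency_ms == 0 {
        return 0.0;
    }
    response.tokens.len() as f64 / (response.latency_ms as f64 / 1000.0)
}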

/// Streaming response chunk
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamChunk {
    /// Request ID this chunk corresponds to
    pub request_id: RequestId,
    /// Text delta for this chunk
    pub text: String,
    /// Token ID for this chunk (if available)
    pub token: Option<TokenId>,
    /// Finish reason if this is the final chunk
    pub finish_reason: Option<FinishReason>,
    /// Token usage (typically only present in the final chunk)
    pub usage: Option<TokenUsage>,
    /// Chunk creation timestamp
    pub created_at: DateTime<Utc>,
    /// Chunk metadata
    pub metadata: HashMap<String, serde_json::Value>,
}
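
// Illustrative sketch (not part of the original API): fold a sequence of
// `StreamChunk`s back into the full generated text. Intermediate chunks carry
// text deltas; the finish reason is expected on the final chunk.
pub fn assemble_stream(chunks: &[StreamChunk]) -> (String, Option<&FinishReason>) {
    let text: String = chunks.iter().map(|c| c.text.as_str()).collect();
    let finish_reason = chunks.iter().rev().find_map(|c| c.finish_reason.as_ref());
    (text, finish_reason)
}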

/// Batch request for processing multiple requests together
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BatchRequest {
    /// Batch identifier
    pub batch_id: BatchId,
    /// Requests in this batch
    pub requests: Vec<InferenceRequest>,
    /// Maximum sequence length for this batch
    pub max_sequence_length: usize,
    /// Batch creation timestamp
    pub created_at: DateTime<Utc>,
}

impl BatchRequest {
    /// Create a new batch request. The batch's `max_sequence_length` is the
    /// largest `max_tokens` among its requests, or 512 for an empty batch.
    pub fn new(requests: Vec<InferenceRequest>) -> Self {
        let max_sequence_length = requests
            .iter()
            .map(|r| r.sampling_params.max_tokens)
            .max()
            .unwrap_or(512);

        Self {
            batch_id: BatchId::new(),
            requests,
            max_sequence_length,
            created_at: Utc::now(),
        }
    }

    /// Get the number of requests in this batch
    pub fn size(&self) -> usize {
        self.requests.len()
    }

    /// Check whether the batch is empty
    pub fn is_empty(&self) -> bool {
        self.requests.is_empty()
    }
}
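
// Test-only sketch: the batch derives `max_sequence_length` from the largest
// `max_tokens` among its requests. Assumes `ModelId: From<&str>` as above and
// that `SamplingParams::max_tokens` is visible within the crate.
#[cfg(test)]
mod batch_request_example {
    use super::*;

    #[test]
    fn batch_tracks_largest_max_tokens() {
        let first = InferenceRequest::new("first prompt", "demo-model");
        let second = InferenceRequest::new("second prompt", "demo-model");
        let expected = first
            .sampling_params
            .max_tokens
            .max(second.sampling_params.max_tokens);

        let batch = BatchRequest::new(vec![first, second]);
        assert_eq!(batch.size(), 2);
        assert!(!batch.is_empty());
        assert_eq!(batch.max_sequence_length, expected);
    }
}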

/// Request state in the scheduler
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum RequestState {
    /// Request is waiting in the queue
    Waiting,
    /// Request is being processed
    Running,
    /// Request was preempted and is waiting to resume
    Preempted,
    /// Request completed successfully
    Completed,
    /// Request failed with an error
    Failed,
    /// Request was cancelled
    Cancelled,
}
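
// Illustrative extension (not in the original API): schedulers commonly treat
// `Completed`, `Failed`, and `Cancelled` as terminal states whose resources
// can be reclaimed.
impl RequestState {
    /// Whether the request can no longer change state.
    pub fn is_terminal(self) -> bool {
        matches!(
            self,
            RequestState::Completed | RequestState::Failed | RequestState::Cancelled
        )
    }
}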

/// Scheduled request with additional state information
#[derive(Debug, Clone)]
pub struct ScheduledRequest {
    /// The original request
    pub request: InferenceRequest,
    /// Current state in the scheduler
    pub state: RequestState,
    /// Allocated cache blocks
    pub allocated_blocks: Vec<crate::BlockId>,
    /// Number of tokens processed so far
    pub tokens_processed: usize,
    /// Estimated completion time
    pub estimated_completion: Option<DateTime<Utc>>,
}

impl ScheduledRequest {
    /// Create a new scheduled request in the `Waiting` state
    pub fn new(request: InferenceRequest) -> Self {
        Self {
            request,
            state: RequestState::Waiting,
            allocated_blocks: Vec::new(),
            tokens_processed: 0,
            estimated_completion: None,
        }
    }

    /// Update the request state
    pub fn set_state(&mut self, state: RequestState) {
        self.state = state;
    }

    /// Add allocated cache blocks
    pub fn add_blocks(&mut self, blocks: Vec<crate::BlockId>) {
        self.allocated_blocks.extend(blocks);
    }

    /// Update the number of tokens processed
    pub fn update_progress(&mut self, tokens_processed: usize) {
        self.tokens_processed = tokens_processed;
    }
}
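
// Test-only walkthrough of the scheduler lifecycle sketched above. The
// `crate::BlockId::from(..)` calls assume `BlockId: From<u64>` purely for
// illustration; substitute the crate's real constructor.
#[cfg(test)]
mod scheduled_request_example {
    use super::*;

    #[test]
    fn walks_through_the_lifecycle() {
        let request = InferenceRequest::new("prompt", "demo-model");
        let mut scheduled = ScheduledRequest::new(request);
        assert_eq!(scheduled.state, RequestState::Waiting);

        scheduled.set_state(RequestState::Running);
        scheduled.add_blocks(vec![crate::BlockId::from(0u64), crate::BlockId::from(1u64)]);
        scheduled.update_progress(42);

        assert_eq!(scheduled.state, RequestState::Running);
        assert_eq!(scheduled.allocated_blocks.len(), 2);
        assert_eq!(scheduled.tokens_processed, 42);
    }
}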