1mod ticket;
6
7use crate::{error::CervoError, state::ModelState, AgentId, BrainId};
8use ticket::Ticket;
9
10use cervo_core::prelude::{Inferer, Response, State};
11#[cfg(feature = "threaded")]
12use rayon::iter::IntoParallelIterator;
13#[cfg(feature = "threaded")]
14use rayon::iter::IntoParallelRefMutIterator;
15#[cfg(feature = "threaded")]
16use rayon::iter::ParallelIterator;
17use std::time::Instant;
18use std::{
19 collections::{BinaryHeap, HashMap},
20 time::Duration,
21};
22
23pub struct Runtime {
25 models: Vec<ModelState>,
26 queue: BinaryHeap<Ticket>,
27 ticket_generation: u64,
28 brain_generation: u16,
29}
30
31impl Default for Runtime {
32 fn default() -> Self {
33 Self::new()
34 }
35}
36
37impl Runtime {
38 pub fn new() -> Self {
40 Self {
41 models: Vec::with_capacity(16),
42 queue: BinaryHeap::with_capacity(16),
43 ticket_generation: 0,
44 brain_generation: 0,
45 }
46 }
47
48 pub fn add_inferer(&mut self, inferer: impl Inferer + 'static + Send) -> BrainId {
50 let id = BrainId(self.brain_generation);
51 self.brain_generation += 1;
52
53 self.models.push(ModelState::new(id, inferer));
54
55 self.queue.push(Ticket(self.ticket_generation, id));
57 self.ticket_generation += 1;
58
59 id
60 }
61
62 pub fn push(
64 &mut self,
65 brain: BrainId,
66 agent: AgentId,
67 state: State<'_>,
68 ) -> Result<(), CervoError> {
69 match self.models.iter_mut().find(|m| m.id == brain) {
70 Some(model) => model.push(agent, state),
71 None => Err(CervoError::UnknownBrain(brain)),
72 }
73 }
74
75 pub fn infer_single(
79 &mut self,
80 brain_id: BrainId,
81 state: State<'_>,
82 ) -> Result<Response<'_>, CervoError> {
83 match self.models.iter_mut().find(|m| m.id == brain_id) {
84 Some(model) => model.infer_single(state),
85 None => Err(CervoError::UnknownBrain(brain_id)),
86 }
87 }
88
89 #[cfg(feature = "threaded")]
91 pub fn run_threaded(&mut self) -> HashMap<BrainId, HashMap<AgentId, Response<'_>>> {
92 self.models
94 .par_iter_mut()
95 .filter(|model| model.needs_to_execute())
96 .map(|model| (model.id, model.run().unwrap()))
97 .collect::<HashMap<BrainId, HashMap<AgentId, Response<'_>>>>()
98 }
99
100 pub fn run(&mut self) -> Result<HashMap<BrainId, HashMap<AgentId, Response<'_>>>, CervoError> {
102 let mut result = HashMap::default();
103
104 for model in self.models.iter_mut() {
105 if !model.needs_to_execute() {
106 continue;
107 }
108
109 result.insert(model.id, model.run()?);
110 }
111
112 Ok(result)
113 }
114
115 #[cfg(feature = "threaded")]
119 pub fn run_for_threaded(
120 &mut self,
121 duration: Duration,
122 ) -> Result<HashMap<BrainId, HashMap<AgentId, Response<'_>>>, CervoError> {
123 let mut available_cpu_time = duration * rayon::current_num_threads() as u32;
124 let mut selected_jobs = Vec::new();
125 let mut unselected_jobs = Vec::new();
126
127 while let Some(ticket) = self.queue.pop() {
128 let Some(model) = self.models.iter().find(|m| m.id == ticket.1) else {
129 continue;
130 };
131
132 if model.needs_to_execute()
133 && (selected_jobs.is_empty() || model.can_run_in_time(available_cpu_time))
134 {
135 available_cpu_time = available_cpu_time.saturating_sub(model.estimated_time());
136 selected_jobs.push((ticket, model));
137 } else {
138 unselected_jobs.push(ticket);
139 }
140 }
141
142 let results = selected_jobs
143 .into_par_iter()
144 .map(|(ticket, model)| (ticket.1, model.run()))
145 .collect::<Vec<(_, _)>>(); let new_tickets = results.iter().map(|(b, _)| {
148 let gen = self.ticket_generation;
149 self.ticket_generation += 1;
150 Ticket(gen, *b)
151 });
152
153 self.queue
154 .extend(unselected_jobs.into_iter().chain(new_tickets));
155
156 results
158 .into_iter()
159 .map(|(b, res)| res.map(|val| (b, val)))
160 .collect::<Result<_, _>>()
161 }
162
163 pub fn run_for(
167 &mut self,
168 mut duration: Duration,
169 ) -> Result<HashMap<BrainId, HashMap<AgentId, Response<'_>>>, CervoError> {
170 let mut result = HashMap::default();
171
172 let mut any_executed = false;
173 let mut executed: Vec<BrainId> = vec![];
174 let mut non_executed = vec![];
175
176 while !self.queue.is_empty() {
177 let ticket = self.queue.pop().unwrap();
178 let res = match self.models.iter().find(|m| m.id == ticket.1) {
179 Some(model) => {
180 if !model.needs_to_execute() || any_executed && !model.can_run_in_time(duration)
181 {
182 Ok(None)
183 } else {
184 let start = Instant::now();
185 let r = model.run();
186
187 let elapsed = start.elapsed();
188 duration = duration.saturating_sub(elapsed);
189
190 any_executed = true;
191 r.map(Some)
192 }
193 }
194
195 None => return Err(CervoError::UnknownBrain(ticket.1)),
196 }?;
197
198 match res {
199 Some(res) => {
200 result.insert(ticket.1, res);
201 executed.push(ticket.1);
202 }
203 None => {
204 non_executed.push(ticket);
205 }
206 }
207 }
208
209 self.queue.extend(non_executed);
210 for id in executed {
211 let gen = self.ticket_generation;
212 self.ticket_generation += 1;
213 self.queue.push(Ticket(gen, id));
214 }
215
216 Ok(result)
217 }
218
219 pub fn output_shapes(&self, brain: BrainId) -> Result<&[(String, Vec<usize>)], CervoError> {
221 match self.models.iter().find(|m| m.id == brain) {
222 Some(model) => Ok(model.inferer.output_shapes()),
223 None => Err(CervoError::UnknownBrain(brain)),
224 }
225 }
226
227 pub fn input_shapes(&self, brain: BrainId) -> Result<&[(String, Vec<usize>)], CervoError> {
229 match self.models.iter().find(|m| m.id == brain) {
230 Some(model) => Ok(model.inferer.input_shapes()),
231 None => Err(CervoError::UnknownBrain(brain)),
232 }
233 }
234
235 pub fn clear(&mut self) -> Result<(), CervoError> {
239 self.queue.clear();
241 self.ticket_generation = 0;
242
243 let mut has_data = vec![];
244 for model in self.models.drain(..) {
245 if model.needs_to_execute() {
246 has_data.push(model.id);
247 }
248 }
249
250 if !has_data.is_empty() {
251 Err(CervoError::OrphanedData(has_data))
252 } else {
253 Ok(())
254 }
255 }
256
257 pub fn remove_inferer(&mut self, brain: BrainId) -> Result<(), CervoError> {
260 let mut to_repush = vec![];
262 while !self.queue.is_empty() {
263 let elem = self.queue.pop().unwrap();
265
266 if elem.1 == brain {
267 break;
268 } else {
269 to_repush.push(elem);
270 }
271 }
272
273 self.queue.extend(to_repush);
274
275 if let Some(index) = self.models.iter().position(|state| state.id == brain) {
276 let state = self.models.remove(index);
278 if state.needs_to_execute() {
279 Err(CervoError::OrphanedData(vec![brain]))
280 } else {
281 Ok(())
282 }
283 } else {
284 Err(CervoError::UnknownBrain(brain))
285 }
286 }
287}
288
#[cfg(test)]
mod tests {
    use super::Runtime;
    use crate::{BrainId, CervoError};
    use cervo_core::prelude::{Inferer, State};
    use std::time::Duration;

    /// Inferer stub whose only work is sleeping for a fixed duration.
    struct DummyInferer {
        sleep_duration: Duration,
    }

    impl Inferer for DummyInferer {
        fn select_batch_size(&self, count: usize) -> usize {
            assert_eq!(count, 1);
            count
        }

        fn infer_raw(
            &self,
            _batch: &mut cervo_core::batcher::ScratchPadView<'_>,
        ) -> anyhow::Result<(), anyhow::Error> {
            std::thread::sleep(self.sleep_duration);
            Ok(())
        }

        fn raw_input_shapes(&self) -> &[(String, Vec<usize>)] {
            &[]
        }

        fn raw_output_shapes(&self) -> &[(String, Vec<usize>)] {
            &[]
        }

        fn begin_agent(&self, _id: u64) {}
        fn end_agent(&self, _id: u64) {}
    }

    /// Builds a runtime holding one `DummyInferer` per entry of `sleeps`
    /// (seconds), returning the brain ids in the same order.
    fn runtime_with_sleeps(sleeps: &[f32]) -> (Runtime, Vec<BrainId>) {
        let mut runtime = Runtime::new();
        let mut keys = vec![];
        for &sleep in sleeps {
            keys.push(runtime.add_inferer(DummyInferer {
                sleep_duration: Duration::from_secs_f32(sleep),
            }));
        }
        (runtime, keys)
    }

    /// Pushes one empty state for agent 0 to every brain in `keys`.
    fn push_all(runtime: &mut Runtime, keys: &[BrainId]) {
        for k in keys {
            runtime.push(*k, 0, State::empty()).unwrap();
        }
    }

    /// Runs ten full push-and-run rounds over all brains.
    fn warm_up(runtime: &mut Runtime, keys: &[BrainId]) {
        for _ in 0..10 {
            push_all(runtime, keys);
            runtime.run().unwrap();
        }
    }

    /// Asserts that `res` failed with `UnknownBrain(BrainId(10))`.
    fn expect_unknown_brain_10<T>(res: Result<T, CervoError>) {
        assert!(res.is_err());
        match res {
            Err(CervoError::UnknownBrain(BrainId(10))) => {}
            _ => panic!("expected CervoError::UnknownBrain"),
        }
    }

    /// Asserts that `res` failed with `OrphanedData` naming exactly `brain`.
    fn expect_orphaned(res: Result<(), CervoError>, brain: BrainId) {
        assert!(res.is_err());
        match res.unwrap_err() {
            CervoError::OrphanedData(keys) => assert_eq!(keys, vec![brain]),
            _ => panic!("expected CervoError::OrphanedData"),
        }
    }

    #[test]
    fn test_run_for_rotation() {
        let (mut runtime, keys) = runtime_with_sleeps(&[0.02, 0.04, 0.06, 0.04]);
        warm_up(&mut runtime, &keys);

        push_all(&mut runtime, &keys);
        let res = runtime.run_for(Duration::from_secs_f32(0.07)).unwrap();
        assert_eq!(res.len(), 2, "got keys: {:?}", res.keys());
        assert!(res.contains_key(&keys[0]));
        assert!(res.contains_key(&keys[1]));

        let res = runtime.run_for(Duration::from_secs_f32(0.07)).unwrap();
        assert_eq!(res.len(), 1);
        assert!(res.contains_key(&keys[2]));

        let res = runtime.run_for(Duration::from_secs_f32(0.07)).unwrap();
        assert_eq!(res.len(), 1);
        assert!(res.contains_key(&keys[3]));

        push_all(&mut runtime, &keys);
        let res = runtime.run_for(Duration::from_secs_f32(0.165)).unwrap();
        assert_eq!(res.len(), 4, "got keys: {:?}", res.keys());
        assert!(res.contains_key(&keys[0]));
        assert!(res.contains_key(&keys[1]));
        assert!(res.contains_key(&keys[2]));
        assert!(res.contains_key(&keys[3]));
    }

    #[test]
    fn test_run_skip_expensive() {
        let (mut runtime, keys) = runtime_with_sleeps(&[0.02, 0.04, 0.06, 0.04]);
        warm_up(&mut runtime, &keys);

        push_all(&mut runtime, &keys);
        let res = runtime.run_for(Duration::from_secs_f32(0.11)).unwrap();
        assert_eq!(res.len(), 3, "got keys: {:?}", res.keys());
        assert!(res.contains_key(&keys[0]));
        assert!(res.contains_key(&keys[1]));
        assert!(res.contains_key(&keys[3]));
    }

    #[test]
    fn test_run_for_greedy() {
        let (mut runtime, keys) = runtime_with_sleeps(&[0.02, 0.04, 0.06]);
        warm_up(&mut runtime, &keys);

        push_all(&mut runtime, &keys);
        let res = runtime.run_for(Duration::from_secs_f32(0.0)).unwrap();
        assert_eq!(res.len(), 1, "got keys: {:?}", res.keys());
        assert!(res.contains_key(&keys[0]));

        let res = runtime.run_for(Duration::from_secs_f32(0.0)).unwrap();
        assert_eq!(res.len(), 1);
        assert!(res.contains_key(&keys[1]));

        let res = runtime.run_for(Duration::from_secs_f32(0.0)).unwrap();
        assert_eq!(res.len(), 1);
        assert!(res.contains_key(&keys[2]));
    }

    #[test]
    fn test_run_single() {
        let mut runtime = Runtime::new();

        let k = runtime.add_inferer(DummyInferer {
            sleep_duration: Duration::from_secs_f32(0.01),
        });

        runtime.infer_single(k, State::empty()).unwrap();
        let r = runtime.run().unwrap();
        assert_eq!(r.len(), 0);
    }

    #[test]
    fn test_run_single_with_push() {
        let mut runtime = Runtime::new();

        let k = runtime.add_inferer(DummyInferer {
            sleep_duration: Duration::from_secs_f32(0.01),
        });

        runtime.push(k, 0, State::empty()).unwrap();

        runtime.infer_single(k, State::empty()).unwrap();
        let mut r = runtime.run().unwrap();
        assert_eq!(r.len(), 1);
        let data = r.remove(&k).unwrap();

        assert_eq!(data.len(), 1);
        assert!(data.contains_key(&0));
    }

    #[test]
    fn unknown_brain_push() {
        let mut runtime = Runtime::new();
        let res = runtime.push(BrainId(10), 0, State::empty());

        expect_unknown_brain_10(res);
    }

    #[test]
    fn unknown_brain_infer_single() {
        let mut runtime = Runtime::new();
        let res = runtime.infer_single(BrainId(10), State::empty());

        expect_unknown_brain_10(res);
    }

    #[test]
    fn unknown_brain_remove() {
        let mut runtime = Runtime::new();
        let res = runtime.remove_inferer(BrainId(10));

        expect_unknown_brain_10(res);
    }

    #[test]
    fn unknown_brain_remove_orphaned() {
        let mut runtime = Runtime::new();
        let k = runtime.add_inferer(DummyInferer {
            sleep_duration: Duration::from_secs_f32(0.1),
        });
        runtime.push(k, 0, State::empty()).unwrap();
        let res = runtime.remove_inferer(k);

        expect_orphaned(res, k);
    }

    #[test]
    fn unknown_brain_clear_orphaned() {
        let mut runtime = Runtime::new();
        let k = runtime.add_inferer(DummyInferer {
            sleep_duration: Duration::from_secs_f32(0.1),
        });
        runtime.push(k, 0, State::empty()).unwrap();
        let res = runtime.clear();

        expect_orphaned(res, k);
    }
}