1use crate::{
2 distributed,
3 paged_attention::block_hash::compute_block_hashes,
4 pipeline::{
5 llg::{constraint_from_llg_grammar, llg_grammar_from_constraint},
6 text_models_inputs_processor::PagedAttentionMeta,
7 CacheBackendMetadata, CacheInstruction,
8 },
9 prefix_cacher::PrefixCacheManagerV2,
10 response::CompletionChoice,
11 scheduler::{Scheduler, SchedulerOutput},
12 search::{self, rag::SearchPipeline},
13 sequence::{SeqStepType, StopReason},
14 tools, CompletionResponse, SchedulerConfig, DEBUG,
15};
16use hanzo_quant::RingConfig;
17use interprocess::local_socket::{traits::Listener, ListenerOptions};
18use llguidance::ParserFactory;
19pub use logger::IntervalLogger;
20use rand::SeedableRng;
21use rand_isaac::Isaac64Rng;
22use serde::{Deserialize, Serialize};
23use std::{
24 collections::HashMap,
25 fmt,
26 io::{BufWriter, Write},
27 net::TcpListener,
28 ops::Deref,
29 str::FromStr,
30 sync::{
31 atomic::{AtomicBool, Ordering},
32 Arc, LazyLock,
33 },
34 time::{Duration, Instant, SystemTime, UNIX_EPOCH},
35};
36use tokio::{
37 select,
38 sync::{
39 mpsc::{error::TryRecvError, Receiver, Sender},
40 Mutex, Notify,
41 },
42 task::JoinHandle,
43};
44
45use crate::{
46 get_mut_arcmutex, handle_pipeline_forward_error,
47 pipeline::{ModelCategory, Pipeline},
48 request::Request,
49 response::{ChatCompletionResponse, Choice, ResponseMessage},
50 sequence::{SequenceRecognizer, SequenceState},
51 Constraint,
52};
53
54mod add_request;
55pub(crate) mod agentic_loop;
56pub use agentic_loop::DEFAULT_MAX_TOOL_ROUNDS;
57pub(crate) mod agentic_session;
58mod file_tools;
59mod logger;
60mod tool_dispatch;
61
/// Out-of-band instruction addressed to a running engine.
///
/// Instructions are stored per engine id in `ENGINE_INSTRUCTIONS` and polled by the
/// engine's main loop.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum EngineInstruction {
    /// Ask the engine's main loop to stop.
    Terminate,
}
65
/// Embedding model selection handed to `SearchPipeline::new` when an engine is built.
///
/// Serialized with serde using snake_case names, with an explicit rename on the variant.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SearchEmbeddingModel {
    /// The default (and currently only) choice; (de)serialized as `"embedding_gemma"`.
    #[default]
    #[serde(rename = "embedding_gemma")]
    EmbeddingGemma300M,
}
74
75impl SearchEmbeddingModel {
76 pub fn hf_model_id(&self) -> &'static str {
77 match self {
78 Self::EmbeddingGemma300M => "google/embeddinggemma-300m",
79 }
80 }
81
82 pub fn variants() -> &'static [&'static str] {
83 &["embedding_gemma"]
84 }
85}
86
87impl fmt::Display for SearchEmbeddingModel {
88 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
89 match self {
90 Self::EmbeddingGemma300M => f.write_str("embedding_gemma"),
91 }
92 }
93}
94
95impl FromStr for SearchEmbeddingModel {
96 type Err = String;
97
98 fn from_str(s: &str) -> Result<Self, Self::Err> {
99 match s.trim().to_ascii_lowercase().as_str() {
100 "embedding_gemma" => Ok(Self::EmbeddingGemma300M),
101 other => Err(format!(
102 "Unknown search embedding model `{other}`. Supported values: {}",
103 Self::variants().join(", ")
104 )),
105 }
106 }
107}
108
/// Fixed seed for the `Isaac64Rng` that `Engine::run` creates and hands to every pipeline step.
const SEED: u64 = 0;
/// Process-wide flag. While it is `true`, `should_terminate_engine_sequences` returns `true`
/// on every thread, and `Engine::run` replicates `Request::TerminateAllSeqsNextStep` to the
/// daemons before scheduling. Nothing in this part of the file resets it.
pub static TERMINATE_ALL_NEXT_STEP: AtomicBool = AtomicBool::new(false);
113
/// Lazily created registry of termination flags, one `Arc<AtomicBool>` per thread, keyed by
/// the `ThreadId` of the thread that first asked for its flag.
static ENGINE_TERMINATE_FLAGS: LazyLock<
    std::sync::Mutex<HashMap<std::thread::ThreadId, Arc<AtomicBool>>>,
> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));
118
119pub fn get_engine_terminate_flag() -> Arc<AtomicBool> {
121 let thread_id = std::thread::current().id();
122 let mut flags = ENGINE_TERMINATE_FLAGS.lock().unwrap();
123 flags
124 .entry(thread_id)
125 .or_insert_with(|| Arc::new(AtomicBool::new(false)))
126 .clone()
127}
128
129pub fn should_terminate_engine_sequences() -> bool {
131 if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
133 return true;
134 }
135 let thread_id = std::thread::current().id();
137 if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
138 if let Some(flag) = flags.get(&thread_id) {
139 return flag.load(Ordering::SeqCst);
140 }
141 }
142 false
143}
144
145pub fn reset_engine_terminate_flag() {
147 let thread_id = std::thread::current().id();
148 if let Ok(flags) = ENGINE_TERMINATE_FLAGS.lock() {
149 if let Some(flag) = flags.get(&thread_id) {
150 flag.store(false, Ordering::SeqCst);
151 }
152 }
153}
154
/// Lazily created table of pending instructions, keyed by a `usize` id. `Engine::run` looks
/// up its own `id` here and stops when the entry is `Some(EngineInstruction::Terminate)`.
pub static ENGINE_INSTRUCTIONS: LazyLock<
    std::sync::Mutex<HashMap<usize, Option<EngineInstruction>>>,
> = LazyLock::new(|| std::sync::Mutex::new(HashMap::new()));
159
/// State of one inference engine: the request channel, the model pipeline, the scheduler
/// and the caches that `Engine::run` drives.
pub struct Engine {
    /// Sending half of the request channel (not used in this part of the file).
    tx: Sender<Request>,
    /// Receiving half of the request channel; drained by `run`.
    rx: Arc<Mutex<Receiver<Request>>>,
    /// Model pipeline that `run` steps on each scheduled batch.
    pipeline: Arc<Mutex<dyn Pipeline>>,
    /// Search pipeline, present only when `new` was given a `SearchEmbeddingModel`.
    search_pipeline: Arc<Mutex<Option<SearchPipeline>>>,
    /// Optional search callback supplied to `new` (not used in this part of the file).
    search_callback: Option<Arc<search::SearchCallback>>,
    /// Tool callbacks supplied to `new` (not used in this part of the file).
    tool_callbacks: tools::ToolCallbacksWithTools,
    /// Scheduler built from the `SchedulerConfig` given to `new`.
    scheduler: Arc<Mutex<dyn Scheduler>>,
    /// Engine id; `run` uses it as the key into `ENGINE_INSTRUCTIONS`. Starts at 0.
    id: Arc<Mutex<usize>>,
    /// Effective KV-cache switch: the `new` argument OR-ed with the pipeline metadata.
    no_kv_cache: bool,
    /// Prefix cache manager passed to every pipeline step.
    prefix_cacher: Arc<Mutex<PrefixCacheManagerV2>>,
    /// Value of the global `DEBUG` flag at construction; enables per-step timing logs.
    is_debug: bool,
    /// Forwarded unchanged to every pipeline step.
    disable_eos_stop: bool,
    /// When set, `run` enables logging on `logger` before entering its loop.
    throughput_logging_enabled: bool,
    /// Receives processed-token counts from `run` and is passed to the scheduler.
    logger: Arc<IntervalLogger>,
    /// Task handles that are aborted when the engine is dropped.
    handles: Arc<Mutex<Vec<JoinHandle<()>>>>,
    /// Wakes `run` while it is waiting with an idle scheduler.
    pending_notify: Arc<Notify>,
    /// Agentic session store supplied to `new` (not used in this part of the file).
    pub(crate) session_store: Arc<std::sync::Mutex<agentic_session::AgenticSessionStore>>,
    /// File store supplied to `new` (not used in this part of the file).
    pub(crate) file_store: crate::files::FileStore,
}
180
181impl Drop for Engine {
182 fn drop(&mut self) {
183 for handle in &*get_mut_arcmutex!(self.handles) {
184 handle.abort();
185 }
186 }
187}
188
189impl Engine {
190 #[allow(clippy::too_many_arguments)]
191 pub fn new(
192 tx: Sender<Request>,
193 rx: Receiver<Request>,
194 pipeline: Arc<Mutex<dyn Pipeline>>,
195 config: SchedulerConfig,
196 mut no_kv_cache: bool,
197 mut no_prefix_cache: bool,
198 prefix_cache_n: usize,
199 disable_eos_stop: bool,
200 throughput_logging_enabled: bool,
201 search_embedding_model: Option<SearchEmbeddingModel>,
202 search_callback: Option<Arc<search::SearchCallback>>,
203 tool_callbacks: tools::ToolCallbacksWithTools,
204 logger: Arc<IntervalLogger>,
205 session_store: Arc<std::sync::Mutex<agentic_session::AgenticSessionStore>>,
206 file_store: crate::files::FileStore,
207 ) -> anyhow::Result<Self> {
208 no_kv_cache |= get_mut_arcmutex!(pipeline).get_metadata().no_kv_cache;
209
210 no_prefix_cache = no_prefix_cache
211 || no_kv_cache
212 || get_mut_arcmutex!(pipeline).get_metadata().no_prefix_cache
213 || prefix_cache_n == 0;
214
215 let search_pipeline = match search_embedding_model {
216 Some(search_embedding_model) => Some(SearchPipeline::new(
217 search_embedding_model,
218 &get_mut_arcmutex!(pipeline).device(),
219 )?),
220 None => None,
221 };
222
223 let scheduler = config.into_scheduler();
224
225 get_mut_arcmutex!(scheduler).set_prefix_caching_enabled(!no_prefix_cache);
228
229 let has_paged_attention = get_mut_arcmutex!(scheduler).kv_cache_manager().is_some();
230
231 Ok(Self {
232 tx,
233 rx: Arc::new(Mutex::new(rx)),
234 pipeline,
235 search_pipeline: Arc::new(Mutex::new(search_pipeline)),
236 search_callback,
237 tool_callbacks,
238 scheduler: scheduler.clone(),
239 id: Arc::new(Mutex::new(0)),
240 no_kv_cache,
241 prefix_cacher: Arc::new(Mutex::new(PrefixCacheManagerV2::new(
242 prefix_cache_n,
243 no_prefix_cache,
244 has_paged_attention,
245 ))),
246 is_debug: DEBUG.load(Ordering::Relaxed),
247 disable_eos_stop,
248 throughput_logging_enabled,
249 logger,
250 handles: Arc::new(Mutex::new(Vec::new())),
251 pending_notify: Arc::new(Notify::new()),
252 session_store,
253 file_store,
254 })
255 }
256
257 #[allow(dead_code)]
259 pub fn max_sequence_length(&self) -> Option<usize> {
260 let pipeline = get_mut_arcmutex!(self.pipeline);
261 let category = pipeline.category();
262
263 if matches!(category, ModelCategory::Diffusion | ModelCategory::Speech) {
264 None
265 } else {
266 Some(pipeline.get_metadata().max_seq_len)
267 }
268 }
269
/// Runs the engine's main loop until termination is requested or the request channel closes.
///
/// Each iteration:
/// 1. stops if `ENGINE_INSTRUCTIONS` holds `Terminate` for this engine's id;
/// 2. drains every request that is immediately available on the channel;
/// 3. if the scheduler has no waiting or running sequences, waits for the next request or a
///    `pending_notify` wake-up and restarts the iteration;
/// 4. otherwise asks the scheduler for a batch and steps the pipeline on it, using either
///    the default cache instructions or paged attention depending on the scheduler output;
/// 5. releases recurrent cache slots of finished sequences (hybrid caches only) and frees
///    finished sequence groups.
///
/// Every received request, and every termination decision, is first replicated to the
/// daemons via `replicate_request_to_daemons`.
pub async fn run(self: Arc<Self>) {
    if self.throughput_logging_enabled {
        self.logger.enable_logging();
    }

    // One RNG, seeded with the fixed `SEED`, shared by every pipeline step.
    let rng = Arc::new(std::sync::Mutex::new(Isaac64Rng::seed_from_u64(SEED)));
    // Sequence ids of the most recent completion batch (default scheduler path only).
    let mut last_completion_ids: Vec<usize> = vec![];
    'lp: loop {
        // True when a `Terminate` instruction is registered under this engine's id.
        let should_terminate = || {
            matches!(
                ENGINE_INSTRUCTIONS
                    .lock()
                    .expect("`ENGINE_INSTRUCTIONS` was poisoned")
                    .get(get_mut_arcmutex!(self.id).deref()),
                Some(Some(EngineInstruction::Terminate))
            )
        };

        if should_terminate() {
            self.replicate_request_to_daemons(&Request::Terminate);
            break 'lp;
        }

        // Drain everything that is already queued, without waiting.
        let mut channel_disconnected = false;
        loop {
            // The receiver lock is released before the request is handled.
            let next_request = {
                let mut rx = self.rx.lock().await;
                rx.try_recv()
            };

            match next_request {
                Ok(request) => {
                    self.replicate_request_to_daemons(&request);
                    if matches!(request, Request::Terminate) {
                        break 'lp;
                    }
                    self.clone().handle_request(request).await;
                }
                Err(TryRecvError::Empty) => break,
                Err(TryRecvError::Disconnected) => {
                    channel_disconnected = true;
                    break;
                }
            }
        }

        // All senders are gone: nothing more can arrive, so stop the engine.
        if channel_disconnected {
            break 'lp;
        }

        let (waiting_len, running_len) = {
            let scheduler = get_mut_arcmutex!(self.scheduler);
            (scheduler.waiting_len(), scheduler.running_len())
        };
        let scheduler_idle = waiting_len == 0 && running_len == 0;

        if scheduler_idle {
            // Re-check termination before blocking.
            if should_terminate() {
                self.replicate_request_to_daemons(&Request::Terminate);
                break 'lp;
            }
            enum WaitEvent {
                Request(Option<Request>),
                Wake,
            }
            let wait_for_request = async {
                let mut rx = self.rx.lock().await;
                rx.recv().await
            };
            tokio::pin!(wait_for_request);
            let wait_for_wake = self.pending_notify.notified();
            tokio::pin!(wait_for_wake);

            // Block until either a request arrives or `pending_notify` fires.
            let event = select! {
                res = &mut wait_for_request => WaitEvent::Request(res),
                _ = &mut wait_for_wake => WaitEvent::Wake,
            };

            match event {
                WaitEvent::Request(Some(request)) => {
                    self.replicate_request_to_daemons(&request);
                    if matches!(request, Request::Terminate) {
                        break 'lp;
                    }
                    self.clone().handle_request(request).await;
                    continue;
                }
                // `recv` returned `None`: the channel is closed.
                WaitEvent::Request(None) => break 'lp,
                WaitEvent::Wake => {
                    continue;
                }
            }
        }

        if TERMINATE_ALL_NEXT_STEP.load(Ordering::SeqCst) {
            self.replicate_request_to_daemons(&Request::TerminateAllSeqsNextStep);
        }

        let run_start = Instant::now();
        // This scheduler guard stays alive for the rest of the iteration.
        let mut scheduler = get_mut_arcmutex!(self.scheduler);
        let scheduled = scheduler.schedule(&self.logger);

        match scheduled {
            SchedulerOutput::DefaultScheduler {
                output: mut scheduled,
            } => {
                if !scheduled.completion.is_empty() {
                    let current_completion_ids: Vec<usize> =
                        scheduled.completion.iter().map(|seq| *seq.id()).collect();
                    let res = {
                        let mut pipeline = get_mut_arcmutex!(self.pipeline);
                        // `In` is only requested when the KV cache is enabled and this batch
                        // differs from the previous completion batch.
                        let pre_op = if !self.no_kv_cache
                            && last_completion_ids != current_completion_ids
                        {
                            CacheInstruction::In
                        } else {
                            CacheInstruction::Nothing
                        };
                        // With the KV cache enabled use `Out`; otherwise reset after the step.
                        let post_op = if !self.no_kv_cache {
                            CacheInstruction::Out
                        } else {
                            CacheInstruction::Reset {
                                load_preallocated_cache: false,
                                reset_non_granular: false,
                            }
                        };

                        // A single flag is passed to `step`, so the whole batch must agree.
                        let return_raw_logits = scheduled.completion[0].return_raw_logits;
                        assert!(
                            scheduled
                                .completion
                                .iter()
                                .all(|seq| seq.return_raw_logits == return_raw_logits),
                            "All sequences must either return raw logits, or not."
                        );

                        pipeline
                            .step(
                                &mut scheduled.completion,
                                false,
                                return_raw_logits,
                                &mut *get_mut_arcmutex!(self.prefix_cacher),
                                self.disable_eos_stop,
                                rng.clone(),
                                CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                            )
                            .await
                    };

                    // NOTE(review): macro defined elsewhere; it is given the loop label
                    // `'lp`, presumably to leave or restart the outer loop on error — confirm.
                    handle_pipeline_forward_error!(
                        "completion step",
                        res,
                        &mut scheduled.completion,
                        self.pipeline,
                        'lp,
                        self.prefix_cacher
                    );

                    // Counted as one token per sequence in the completion batch.
                    self.logger.add_tokens_processed(scheduled.completion.len());

                    last_completion_ids = current_completion_ids;
                }

                if !scheduled.prompt.is_empty() {
                    // Stamp the start of prompt processing (milliseconds since the UNIX epoch).
                    let pre_step_now = SystemTime::now()
                        .duration_since(UNIX_EPOCH)
                        .expect("Time travel has occurred!")
                        .as_millis();
                    for seq in scheduled.prompt.iter_mut() {
                        seq.prompt_timestamp = Some(pre_step_now);
                        seq.set_step_start_instant();
                    }

                    let prompt_exec_time = {
                        let mut pipeline = get_mut_arcmutex!(self.pipeline);

                        let post_op = if !self.no_kv_cache {
                            CacheInstruction::Out
                        } else {
                            CacheInstruction::Reset {
                                load_preallocated_cache: false,
                                reset_non_granular: false,
                            }
                        };

                        let return_raw_logits = scheduled.prompt[0].return_raw_logits;
                        assert!(
                            scheduled
                                .prompt
                                .iter()
                                .all(|seq| seq.return_raw_logits == return_raw_logits),
                            "All sequences must either return raw logits, or not."
                        );

                        // Decided from the first sequence only: a non-zero token offset uses
                        // `In`, otherwise the cache is reset with the preallocated cache loaded.
                        let pre_op = if scheduled.prompt[0].token_offset() != 0 {
                            CacheInstruction::In
                        } else {
                            CacheInstruction::Reset {
                                load_preallocated_cache: true,
                                reset_non_granular: false,
                            }
                        };

                        pipeline
                            .step(
                                &mut scheduled.prompt,
                                true,
                                return_raw_logits,
                                &mut *get_mut_arcmutex!(self.prefix_cacher),
                                self.disable_eos_stop,
                                rng.clone(),
                                CacheBackendMetadata::DefaultInstructions { pre_op, post_op },
                            )
                            .await
                    };

                    // On this path the macro's value is used as the step's duration.
                    let prompt_exec_time = handle_pipeline_forward_error!(
                        "prompt step",
                        prompt_exec_time,
                        &mut scheduled.prompt,
                        self.pipeline,
                        'lp,
                        self.prefix_cacher
                    );

                    let total_processed_tokens: usize = scheduled
                        .prompt
                        .iter()
                        .map(|seq| seq.get_toks().len())
                        .sum();
                    self.logger.add_tokens_processed(total_processed_tokens);

                    for seq in scheduled.prompt.iter_mut() {
                        // One-shot sequences are done after the prompt step; the others move
                        // on to completion.
                        match seq.sequence_stepping_type() {
                            SeqStepType::OneShot => {
                                seq.set_state(SequenceState::Done(StopReason::GeneratedImage))
                            }
                            SeqStepType::PromptAndDecode => {
                                seq.set_state(SequenceState::RunningCompletion)
                            }
                        }
                        let now = SystemTime::now()
                            .duration_since(UNIX_EPOCH)
                            .expect("Time travel has occurred!")
                            .as_millis();
                        // Throughput uses the batch's execution time for every sequence.
                        #[allow(clippy::cast_precision_loss)]
                        let prompt_tok_per_sec =
                            seq.len() as f32 / prompt_exec_time.as_secs_f32();
                        seq.prompt_tok_per_sec = prompt_tok_per_sec;
                        seq.prompt_timestamp = Some(now);
                        seq.total_prompt_time = Some(prompt_exec_time.as_millis());
                        seq.step_start_instant = None;
                    }
                    // Clearing the ids makes the next non-empty completion batch compare as
                    // different, so it uses `CacheInstruction::In` when the KV cache is on.
                    last_completion_ids = vec![];
                }

                if self.is_debug {
                    // Despite the name this holds seconds; it is scaled to ms when logged.
                    let ms_from_last_run = run_start.elapsed().as_secs_f64();
                    let total_len = scheduled.prompt.len() + scheduled.completion.len();
                    if total_len > 0 {
                        let prompt_lengths = scheduled
                            .prompt
                            .iter()
                            .map(|seq| seq.len().to_string())
                            .collect::<Vec<_>>()
                            .join(", ");

                        let completion_lengths = scheduled
                            .completion
                            .iter()
                            .map(|seq| seq.len().to_string())
                            .collect::<Vec<_>>()
                            .join(", ");

                        tracing::info!(
                            "Prompt[{}] Completion[{}] - {}ms",
                            prompt_lengths,
                            completion_lengths,
                            ms_from_last_run * 1000.,
                        );
                    }
                }
            }
            SchedulerOutput::PagedAttention { mut output } => {
                if !output.scheduled.is_empty() {
                    // The first scheduled sequence decides whether this is a prompt step.
                    let is_prompt = get_mut_arcmutex!(output.scheduled[0]).is_prompt();

                    if is_prompt {
                        let now = SystemTime::now()
                            .duration_since(UNIX_EPOCH)
                            .expect("Time travel has occurred!")
                            .as_millis();
                        for seq in output.scheduled.iter() {
                            let mut seq_guard = get_mut_arcmutex!(seq);
                            seq_guard.prompt_timestamp = Some(now);
                            seq_guard.set_step_start_instant();
                        }
                    }

                    // Lock every scheduled sequence; the guards live until this branch ends.
                    let mut guards = output
                        .scheduled
                        .iter_mut()
                        .map(|seq| seq.lock().unwrap())
                        .collect::<Vec<_>>();

                    // Plain `&mut` views of the locked sequences, as passed to `step`.
                    let mut guards_mut =
                        guards.iter_mut().map(|seq| &mut **seq).collect::<Vec<_>>();

                    let res = {
                        let mut pipeline = get_mut_arcmutex!(self.pipeline);

                        // Panics if the scheduler reports no block size.
                        let block_size = scheduler.block_size().unwrap();

                        // Prompt step on a hybrid cache: for each sequence with a cached
                        // prefix, try to restore the recurrent state saved for that prefix;
                        // if that is not possible, fall back to computing the whole prompt.
                        if is_prompt && pipeline.cache().is_hybrid() {
                            let mut hybrid_cache = pipeline.cache().hybrid();
                            let mut prefix_cacher = get_mut_arcmutex!(self.prefix_cacher);
                            let kv_cache_manager = scheduler.kv_cache_manager().unwrap();

                            for seq in guards_mut.iter_mut() {
                                let cached_prefix_len = seq.prefix_cache_len();
                                if cached_prefix_len == 0 {
                                    continue;
                                }

                                let mut fallback_to_full_prompt = false;

                                let slot_idx = match seq.recurrent_state_idx() {
                                    Some(idx) => idx,
                                    None => {
                                        tracing::warn!("Sequence {} has paged prefix hit but no recurrent_state_idx; recomputing full prompt.", seq.id());
                                        fallback_to_full_prompt = true;
                                        // Placeholder: never used, because the fallback flag
                                        // skips the restore below.
                                        0usize
                                    }
                                };

                                if !fallback_to_full_prompt {
                                    if cached_prefix_len % block_size != 0 {
                                        tracing::warn!(
                                            "Sequence {} has non-aligned paged prefix len {}; recomputing full prompt.",
                                            seq.id(),
                                            cached_prefix_len
                                        );
                                        fallback_to_full_prompt = true;
                                    } else {
                                        // Snapshots are looked up by the hashes of the blocks
                                        // covering the cached prefix.
                                        let num_prefix_blocks = cached_prefix_len / block_size;
                                        let block_hashes = compute_block_hashes(
                                            seq.get_toks(),
                                            block_size,
                                            seq.mm_features(),
                                            &[],
                                        );
                                        if block_hashes.len() < num_prefix_blocks {
                                            fallback_to_full_prompt = true;
                                        } else if let Some(snapshots) = prefix_cacher
                                            .get_paged_recurrent_prefix(
                                                &block_hashes[..num_prefix_blocks],
                                            )
                                        {
                                            if let Err(e) = hybrid_cache
                                                .restore_recurrent_state(slot_idx, &snapshots)
                                            {
                                                tracing::warn!(
                                                    "Failed restoring paged recurrent prefix state for sequence {}: {e}",
                                                    seq.id()
                                                );
                                                fallback_to_full_prompt = true;
                                            }
                                        } else {
                                            tracing::warn!(
                                                "No recurrent prefix snapshot for sequence {} at cached prefix length {}; recomputing full prompt.",
                                                seq.id(),
                                                cached_prefix_len
                                            );
                                            fallback_to_full_prompt = true;
                                        }
                                    }
                                }

                                if fallback_to_full_prompt {
                                    // Release the sequence's KV blocks and allocate fresh
                                    // slots for all of its tokens.
                                    let seq_id = *seq.id();
                                    let num_tokens = seq.get_toks().len();
                                    let mut kv_mgr = get_mut_arcmutex!(kv_cache_manager);
                                    kv_mgr.free(seq_id);
                                    let realloc_ok = kv_mgr
                                        .allocate_slots(seq_id, num_tokens, &[])
                                        .is_some();
                                    drop(kv_mgr);

                                    if !realloc_ok {
                                        tracing::warn!(
                                            "Failed to reallocate fresh paged KV blocks for sequence {} after recurrent-prefix fallback.",
                                            seq_id
                                        );
                                        seq.set_state(SequenceState::FinishedIgnored);
                                    }
                                    seq.set_prefix_cache_len(0);
                                }
                            }

                            // Drop sequences that were marked finished during the fallback.
                            guards_mut.retain(|seq| !seq.is_finished_paged_attn());
                        }

                        if guards_mut.is_empty() {
                            // Nothing left to run; report a zero-length step.
                            Ok(Duration::ZERO)
                        } else {
                            let metadata = PagedAttentionMeta {
                                block_size,
                                sliding_window: pipeline.get_metadata().sliding_window,
                                kv_cache_manager: scheduler.kv_cache_manager().unwrap(),
                            };

                            let return_raw_logits = guards_mut[0].return_raw_logits;
                            assert!(
                                guards_mut
                                    .iter()
                                    .all(|seq| seq.return_raw_logits == return_raw_logits),
                                "All sequences must either return raw logits, or not."
                            );

                            pipeline
                                .step(
                                    &mut guards_mut,
                                    is_prompt,
                                    return_raw_logits,
                                    &mut *get_mut_arcmutex!(self.prefix_cacher),
                                    self.disable_eos_stop,
                                    rng.clone(),
                                    CacheBackendMetadata::PagedAttention { metadata },
                                )
                                .await
                        }
                    };

                    handle_pipeline_forward_error!(
                        "step",
                        res,
                        &mut guards_mut,
                        self.pipeline,
                        'lp,
                        self.prefix_cacher
                    );

                    // Sequences that report `is_prompt()` count all their tokens; the rest
                    // count one token each.
                    let total_processed_tokens: usize = guards_mut
                        .iter()
                        .map(|seq| {
                            if seq.is_prompt() {
                                seq.get_toks().len()
                            } else {
                                1
                            }
                        })
                        .sum();
                    self.logger.add_tokens_processed(total_processed_tokens);

                    {
                        // Hybrid caches only: snapshot the recurrent state of every sequence
                        // whose token count is a non-zero multiple of `block_size`, and
                        // register it in the prefix cacher under the sequence's block hashes.
                        let pipeline = get_mut_arcmutex!(self.pipeline);
                        if pipeline.cache().is_hybrid() {
                            let block_size = scheduler.block_size().unwrap();
                            let hybrid_cache = pipeline.cache().hybrid();
                            let mut prefix_cacher = get_mut_arcmutex!(self.prefix_cacher);

                            for seq in guards_mut.iter() {
                                let seq_len = seq.get_toks().len();
                                if seq_len == 0 || seq_len % block_size != 0 {
                                    continue;
                                }

                                let Some(slot_idx) = seq.recurrent_state_idx() else {
                                    continue;
                                };

                                let snapshots = match hybrid_cache
                                    .snapshot_recurrent_state(slot_idx)
                                {
                                    Ok(snapshots) => snapshots,
                                    Err(e) => {
                                        tracing::warn!(
                                            "Failed snapshotting recurrent state for sequence {}: {e}",
                                            seq.id()
                                        );
                                        continue;
                                    }
                                };
                                if snapshots.is_empty() {
                                    continue;
                                }

                                let num_blocks = seq_len / block_size;
                                let block_hashes = compute_block_hashes(
                                    seq.get_toks(),
                                    block_size,
                                    seq.mm_features(),
                                    &[],
                                );
                                if block_hashes.len() < num_blocks {
                                    continue;
                                }
                                prefix_cacher.add_paged_recurrent_prefix(
                                    block_hashes[..num_blocks].to_vec(),
                                    snapshots,
                                );
                            }
                        }
                    }

                    if self.is_debug {
                        // Despite the name this holds seconds; it is scaled to ms when logged.
                        let ms_from_last_run = run_start.elapsed().as_secs_f64();
                        let total_len = guards.len();
                        if total_len > 0 {
                            let lengths = guards
                                .iter()
                                .map(|seq| seq.len().to_string())
                                .collect::<Vec<_>>()
                                .join(", ");

                            let (prompt_lengths, completion_lengths) = if is_prompt {
                                (lengths, "".to_string())
                            } else {
                                ("".to_string(), lengths)
                            };

                            tracing::info!(
                                "Prompt[{}] Completion[{}] - {}ms",
                                prompt_lengths,
                                completion_lengths,
                                ms_from_last_run * 1000.,
                            );
                        }
                    }

                    if is_prompt {
                        // Record prompt throughput and timing from the instant stamped before
                        // the step; each guard is released as it goes out of scope.
                        #[allow(clippy::cast_precision_loss)]
                        for mut seq in guards {
                            if let Some(start) = seq.step_start_instant {
                                let duration = start.elapsed();
                                seq.prompt_tok_per_sec =
                                    seq.len() as f32 / duration.as_secs_f32();
                                seq.total_prompt_time = Some(duration.as_millis());
                                seq.step_start_instant = None;
                            }
                            let now = SystemTime::now()
                                .duration_since(UNIX_EPOCH)
                                .expect("Time travel has occurred!")
                                .as_millis();
                            seq.prompt_timestamp = Some(now);
                        }
                    }
                }
            }
        }

        {
            // With a KV cache and a hybrid cache, release the recurrent slots that the
            // scheduler reports for finished sequences.
            let pipeline = get_mut_arcmutex!(self.pipeline);
            if !pipeline.get_metadata().no_kv_cache && pipeline.cache().is_hybrid() {
                let recurrent_indices = scheduler.get_finished_recurrent_indices();
                if !recurrent_indices.is_empty() {
                    let mut hybrid_cache = pipeline.cache().hybrid();
                    for idx in recurrent_indices {
                        hybrid_cache.free_seq(idx);
                    }
                }
            }
        }
        scheduler.free_finished_sequence_groups();
    }
}
851
852 fn build_sequence_recognizer(
853 factory: &Option<Arc<ParserFactory>>,
854 constraint: &Constraint,
855 ) -> anyhow::Result<SequenceRecognizer> {
856 if let Some(grm) = llg_grammar_from_constraint(constraint)? {
857 let factory = factory
858 .as_ref()
859 .ok_or_else(|| anyhow::anyhow!("No token environment (llg_factory) found."))?;
860 let llg = constraint_from_llg_grammar(factory, grm)?;
861 Ok(SequenceRecognizer::Llguidance(Box::new(llg)))
862 } else {
863 Ok(SequenceRecognizer::None)
864 }
865 }
866
867 fn replicate_request_to_daemons(&self, request: &Request) {
868 if !distributed::is_daemon() && hanzo_quant::distributed::use_nccl() {
869 let name = distributed::ipc_name().unwrap();
870 let num_workers =
871 hanzo_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
872 let listener = ListenerOptions::new().name(name).create_sync().unwrap();
873
874 for _ in 0..num_workers {
875 let stream = listener.accept().unwrap();
876 let mut writer = BufWriter::new(stream);
877 let req = format!("{}\n", serde_json::to_string(&request).unwrap());
878 writer.write_all(req.as_bytes()).unwrap();
879 }
880 } else if !distributed::is_daemon() && cfg!(feature = "ring") {
881 let num_workers =
882 hanzo_quant::distributed::get_global_tp_size_from_devices().unwrap() - 1;
883 let master_port = RingConfig::load().master_port;
884 let listener =
885 TcpListener::bind(format!("0.0.0.0:{master_port}")).expect("bind replicator");
886
887 for _ in 0..num_workers {
888 let (stream, _) = listener.accept().unwrap();
889 let mut writer = BufWriter::new(stream);
890 let req = format!("{}\n", serde_json::to_string(&request).unwrap());
891 writer.write_all(req.as_bytes()).unwrap();
892 }
893 }
894 }
895}