1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;
use chrono::{DateTime, Utc};
use vllora_core::handler::CallbackHandlerFn;
use vllora_core::usage::InMemoryStorage;
use vllora_llm::types::gateway::ImageGenerationModelUsage;
use vllora_llm::types::ModelEventType;
use crate::{cost::GatewayCostCalculator, usage::update_usage};
pub fn init_callback_handler(
storage: Arc<Mutex<InMemoryStorage>>,
calculator: GatewayCostCalculator,
) -> CallbackHandlerFn {
let (tx, mut rx) = tokio::sync::broadcast::channel(10000);
let start_times = Arc::new(Mutex::new(HashMap::<String, DateTime<Utc>>::new()));
let ttft_times = Arc::new(Mutex::new(HashMap::<String, i64>::new()));
let callback_handler = CallbackHandlerFn(Some(tx));
tokio::spawn({
let start_times = start_times.clone();
let ttft_times = ttft_times.clone();
async move {
loop {
if let Ok(model_event) = rx.recv().await {
tracing::debug!(target: "model_event", "Received model event: {model_event:#?}");
match &model_event.event.event {
ModelEventType::LlmStart(_) => {
let mut times = start_times.lock().await;
times.insert(
model_event.event.trace_id.clone(),
model_event.event.timestamp,
);
tracing::debug!(
"Recorded LlmStart time for trace {}",
model_event.event.trace_id
);
}
ModelEventType::LlmFirstToken(_) => {
let ttft = {
let times = start_times.lock().await;
if let Some(start_time) = times.get(&model_event.event.trace_id) {
let duration = model_event.event.timestamp - *start_time;
let ttft_ms = duration.num_milliseconds();
let mut ttft_map = ttft_times.lock().await;
ttft_map.insert(model_event.event.trace_id.clone(), ttft_ms);
Some(ttft_ms)
} else {
tracing::warn!(
"No start time found for trace {}",
model_event.event.trace_id
);
None
}
};
if let Some(ttft_ms) = ttft {
tracing::info!(
"TTFT for trace {}: {} milliseconds",
model_event.event.trace_id,
ttft_ms
);
}
}
ModelEventType::LlmStop(finish_event) => {
let model_name = finish_event.model_name.clone();
let usage = finish_event.usage.clone();
// Calculate duration and get ttft
let (duration, ttft) = {
let mut times = start_times.lock().await;
let mut ttft_map = ttft_times.lock().await;
let duration =
times.remove(&model_event.event.trace_id).map(|start_time| {
let duration = model_event.event.timestamp - start_time;
duration.num_milliseconds()
});
if duration.is_none() {
tracing::warn!(
"No start time found for trace {}",
model_event.event.trace_id
);
}
let ttft = ttft_map.remove(&model_event.event.trace_id);
(duration, ttft)
};
if let Some(model) = &model_event.model {
let result = update_usage(
storage.clone(),
&calculator,
&model_name,
&model.provider_name,
usage
.map(
vllora_llm::types::gateway::Usage::CompletionModelUsage,
)
.as_ref(),
duration.map(|d| d as u64),
ttft.map(|t| t as u64),
&model.price,
)
.await;
if let Err(e) = result {
tracing::error!("Error setting model usage: {e}");
};
}
}
ModelEventType::ImageGenerationFinish(finish_event) => {
if let Some(model) = &model_event.model {
let model_name = finish_event.model_name.clone();
let result = update_usage(
storage.clone(),
&calculator,
&model_name,
&model.provider_name,
Some(
&vllora_llm::types::gateway::Usage::ImageGenerationModelUsage(
ImageGenerationModelUsage {
quality: finish_event.quality.clone(),
size: finish_event.size.clone().into(),
images_count: finish_event.count_of_images,
steps_count: finish_event.steps,
},
),
),
None,
None,
&model.price,
)
.await;
if let Err(e) = result {
tracing::error!("Error setting model usage: {e}");
}
}
}
_ => {}
}
}
}
}
});
callback_handler
}