1use crate::telemetry::{
4 BufferingConfig, PlatformTelemetrySubscription, TelemetryEvent, TelemetryEventType,
5 TelemetrySubscription,
6};
7use chrono::Utc;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::{Mutex, RwLock};
12use tokio::task::JoinHandle;
13
14#[derive(Debug, Clone)]
16pub(crate) struct Subscription {
17 #[allow(dead_code)]
19 pub extension_id: String,
20
21 #[allow(dead_code)]
23 pub extension_name: String,
24
25 pub event_types: Vec<TelemetryEventType>,
27
28 pub destination_uri: String,
30
31 #[allow(dead_code)]
33 pub buffering: BufferingConfig,
34}
35
36impl Default for BufferingConfig {
37 fn default() -> Self {
38 Self {
39 max_items: Some(10000),
40 max_bytes: Some(262144),
41 timeout_ms: Some(1000),
42 }
43 }
44}
45
/// Maximum number of events kept in the capture log; once reached, the oldest
/// captured event is evicted before a new one is recorded.
const MAX_CAPTURED_EVENTS: usize = 10_000;
48
49#[derive(Debug)]
51pub(crate) struct TelemetryState {
52 subscriptions: Mutex<HashMap<String, Subscription>>,
54
55 event_buffers: Mutex<HashMap<String, Vec<TelemetryEvent>>>,
57
58 delivery_handles: Mutex<HashMap<String, JoinHandle<()>>>,
60
61 http_client: reqwest::Client,
63 http1_client: reqwest::Client,
66
67 capture_mode: Mutex<bool>,
69
70 captured_events: Mutex<Vec<TelemetryEvent>>,
75
76 flush_lock: RwLock<()>,
82}
83
84impl TelemetryState {
85 pub fn new() -> Self {
87 let http1_client = reqwest::Client::builder()
90 .http1_only()
91 .build()
92 .unwrap_or_else(|_| reqwest::Client::new());
93
94 Self {
95 subscriptions: Mutex::new(HashMap::new()),
96 event_buffers: Mutex::new(HashMap::new()),
97 delivery_handles: Mutex::new(HashMap::new()),
98 http_client: reqwest::Client::new(),
99 http1_client,
100 capture_mode: Mutex::new(false),
101 captured_events: Mutex::new(Vec::new()),
102 flush_lock: RwLock::new(()),
103 }
104 }
105
106 pub async fn subscribe(
114 self: &Arc<Self>,
115 extension_id: String,
116 extension_name: String,
117 subscription: TelemetrySubscription,
118 ) {
119 let buffering = subscription.buffering.unwrap_or_default();
120
121 let destination_uri = subscription
124 .destination
125 .uri
126 .replace("sandbox.localdomain", "127.0.0.1");
127
128 let subscribed_types: Vec<String> = subscription
129 .types
130 .iter()
131 .map(|t| format!("{:?}", t).to_lowercase())
132 .collect();
133
134 let sub = Subscription {
135 extension_id: extension_id.clone(),
136 extension_name: extension_name.clone(),
137 event_types: subscription.types,
138 destination_uri,
139 buffering: buffering.clone(),
140 };
141
142 self.subscriptions
143 .lock()
144 .await
145 .insert(extension_id.clone(), sub);
146
147 self.event_buffers
148 .lock()
149 .await
150 .insert(extension_id.clone(), Vec::new());
151
152 self.start_delivery_task(extension_id, buffering).await;
153
154 let subscription_event = TelemetryEvent {
155 time: Utc::now(),
156 event_type: "platform.telemetrySubscription".to_string(),
157 record: serde_json::to_value(PlatformTelemetrySubscription {
158 name: extension_name,
159 state: "Subscribed".to_string(),
160 types: subscribed_types,
161 })
162 .unwrap_or_default(),
163 };
164
165 self.broadcast_event(subscription_event, TelemetryEventType::Platform)
166 .await;
167 }
168
169 async fn start_delivery_task(
171 self: &Arc<Self>,
172 extension_id: String,
173 buffering: BufferingConfig,
174 ) {
175 let state = Arc::clone(self);
176 let timeout_ms = buffering.timeout_ms.unwrap_or(1000);
177 let max_items = buffering.max_items.unwrap_or(10000);
178 let ext_id_for_insert = extension_id.clone();
179
180 let handle = tokio::spawn(async move {
181 let mut interval = tokio::time::interval(Duration::from_millis(timeout_ms as u64));
182
183 loop {
184 interval.tick().await;
185
186 let events = {
187 let mut buffers = state.event_buffers.lock().await;
188 if let Some(buffer) = buffers.get_mut(&extension_id) {
189 if buffer.is_empty() {
190 continue;
191 }
192
193 let count = buffer.len().min(max_items as usize);
194 buffer.drain(..count).collect::<Vec<_>>()
195 } else {
196 break;
197 }
198 };
199
200 if let Some(sub) = state.subscriptions.lock().await.get(&extension_id) {
201 let uri = sub.destination_uri.clone();
202 let client = state.http_client.clone();
203
204 tracing::debug!(
205 count = events.len(),
206 uri = %uri,
207 "Sending telemetry events"
208 );
209 match client.post(&uri).json(&events).send().await {
210 Ok(resp) => {
211 tracing::debug!(status = %resp.status(), "Telemetry delivery response");
212 }
213 Err(e) => {
214 tracing::warn!(error = %e, "Telemetry delivery error");
215 }
216 }
217 }
218 }
219 });
220
221 self.delivery_handles
222 .lock()
223 .await
224 .insert(ext_id_for_insert, handle);
225 }
226
227 pub async fn broadcast_event(&self, event: TelemetryEvent, event_type: TelemetryEventType) {
238 {
240 let mut captured = self.captured_events.lock().await;
241 if captured.len() >= MAX_CAPTURED_EVENTS {
242 captured.remove(0);
244 }
245 captured.push(event.clone());
246 }
247
248 let subscriptions = self.subscriptions.lock().await;
249 let mut buffers = self.event_buffers.lock().await;
250
251 tracing::trace!(
252 event_type = ?event_type,
253 subscriptions = subscriptions.len(),
254 buffers = buffers.len(),
255 "Broadcasting telemetry event"
256 );
257
258 for (ext_id, sub) in subscriptions.iter() {
259 tracing::trace!(
260 extension_id = %ext_id,
261 event_types = ?sub.event_types,
262 matches = sub.event_types.contains(&event_type),
263 "Checking subscription"
264 );
265 if sub.event_types.contains(&event_type)
266 && let Some(buffer) = buffers.get_mut(ext_id)
267 {
268 tracing::trace!(extension_id = %ext_id, "Adding event to buffer");
269 let max_items = sub.buffering.max_items.unwrap_or(10000) as usize;
270 if buffer.len() >= max_items {
271 let excess = buffer.len() - max_items + 1;
273 buffer.drain(..excess);
274 tracing::warn!(
275 extension_id = %ext_id,
276 dropped_events = excess,
277 "Telemetry buffer overflow, dropped oldest events"
278 );
279 }
280 buffer.push(event.clone());
281 }
282 }
283 }
284
285 #[allow(dead_code)]
287 pub async fn get_subscriptions(&self) -> Vec<Subscription> {
288 self.subscriptions.lock().await.values().cloned().collect()
289 }
290
291 #[allow(dead_code)]
293 pub async fn is_subscribed(&self, extension_id: &str) -> bool {
294 self.subscriptions.lock().await.contains_key(extension_id)
295 }
296
297 pub async fn enable_capture(&self) {
302 *self.capture_mode.lock().await = true;
303 }
304
305 pub async fn get_captured_events(&self) -> Vec<TelemetryEvent> {
309 self.captured_events.lock().await.clone()
310 }
311
312 pub async fn get_captured_events_by_type(&self, event_type: &str) -> Vec<TelemetryEvent> {
318 self.captured_events
319 .lock()
320 .await
321 .iter()
322 .filter(|e| e.event_type == event_type)
323 .cloned()
324 .collect()
325 }
326
327 pub async fn clear_captured_events(&self) {
329 self.captured_events.lock().await.clear();
330 }
331
332 pub async fn flush_all(&self) {
341 tracing::debug!("Starting flush_all");
342 let _guard = self.flush_lock.read().await;
344
345 let subscriptions = self.subscriptions.lock().await;
346 let mut buffers = self.event_buffers.lock().await;
347
348 tracing::debug!(
349 subscriptions = subscriptions.len(),
350 buffers = buffers.len(),
351 "Flushing telemetry buffers"
352 );
353
354 for (ext_id, sub) in subscriptions.iter() {
355 if let Some(buffer) = buffers.get_mut(ext_id) {
356 if buffer.is_empty() {
357 tracing::trace!(extension_id = %ext_id, "Buffer empty, skipping");
358 continue;
359 }
360
361 let events = std::mem::take(buffer);
362 let uri = sub.destination_uri.clone();
363
364 tracing::debug!(
365 extension_id = %ext_id,
366 count = events.len(),
367 uri = %uri,
368 "Flushing events to extension"
369 );
370
371 let mut attempts = 0;
374 let max_attempts = 5;
375
376 loop {
377 attempts += 1;
378 match self.http1_client.post(&uri).json(&events).send().await {
381 Ok(resp) if resp.status().is_success() => {
382 tracing::debug!(
383 status = %resp.status(),
384 attempts,
385 "Flush successful"
386 );
387 break;
388 }
389 Ok(resp) => {
390 let status = resp.status();
391 let body = resp.text().await.unwrap_or_default();
392 tracing::debug!(
393 status = %status,
394 body = %body,
395 attempts,
396 "Flush attempt failed"
397 );
398 if attempts >= max_attempts {
399 tracing::warn!(
400 extension_id = %ext_id,
401 status = %status,
402 "Failed to flush telemetry events after {} attempts",
403 max_attempts
404 );
405 break;
406 }
407 tokio::time::sleep(Duration::from_millis(200)).await;
408 }
409 Err(e) => {
410 tracing::debug!(
411 error = %e,
412 attempts,
413 "Flush attempt error"
414 );
415 if attempts >= max_attempts {
416 tracing::warn!(
417 extension_id = %ext_id,
418 error = %e,
419 "Failed to flush telemetry events after {} attempts",
420 max_attempts
421 );
422 break;
423 }
424 tokio::time::sleep(Duration::from_millis(200)).await;
425 }
426 }
427 }
428 }
429 }
430 tracing::debug!("flush_all complete");
431 }
432
433 pub async fn wait_for_flush_complete(&self, timeout: Duration) {
443 let result = tokio::time::timeout(timeout, self.flush_lock.write()).await;
444 if result.is_err() {
445 tracing::warn!(
446 timeout_ms = timeout.as_millis(),
447 "Timed out waiting for flush operations to complete"
448 );
449 }
450 }
452
453 pub async fn shutdown(&self) {
457 let mut handles = self.delivery_handles.lock().await;
458 for (_, handle) in handles.drain() {
459 handle.abort();
460 }
461 }
462}
463
464impl Default for TelemetryState {
465 fn default() -> Self {
466 Self::new()
467 }
468}