lambda_simulator/
telemetry_state.rs

1//! Telemetry subscription state management.
2
3use crate::telemetry::{
4    BufferingConfig, PlatformTelemetrySubscription, TelemetryEvent, TelemetryEventType,
5    TelemetrySubscription,
6};
7use chrono::Utc;
8use std::collections::HashMap;
9use std::sync::Arc;
10use std::time::Duration;
11use tokio::sync::{Mutex, RwLock};
12use tokio::task::JoinHandle;
13
/// Information about an extension's telemetry subscription.
///
/// One value is stored per subscribed extension in `TelemetryState`, keyed by
/// the extension ID. It is built in `TelemetryState::subscribe` from the
/// extension's subscription request.
#[derive(Debug, Clone)]
pub(crate) struct Subscription {
    /// Extension identifier.
    #[allow(dead_code)]
    pub extension_id: String,

    /// Extension name.
    #[allow(dead_code)]
    pub extension_name: String,

    /// Types of events subscribed to.
    ///
    /// A broadcast event is queued for this extension only when its event type
    /// is contained in this list.
    pub event_types: Vec<TelemetryEventType>,

    /// HTTP URI to send events to.
    ///
    /// `subscribe` stores the requested URI with `sandbox.localdomain`
    /// replaced by `127.0.0.1`.
    pub destination_uri: String,

    /// Buffering configuration.
    ///
    /// `max_items` is read when events are queued to cap the per-extension buffer.
    #[allow(dead_code)]
    pub buffering: BufferingConfig,
}
35
36impl Default for BufferingConfig {
37    fn default() -> Self {
38        Self {
39            max_items: Some(10000),
40            max_bytes: Some(262144),
41            timeout_ms: Some(1000),
42        }
43    }
44}
45
/// Maximum number of events stored in the internal test capture buffer.
///
/// Once the buffer holds this many events, the oldest one is removed before a
/// new one is appended.
const MAX_CAPTURED_EVENTS: usize = 10000;
48
/// State for managing telemetry subscriptions and event delivery.
///
/// Each collection sits behind its own async mutex so the state can be shared
/// between request handlers and the spawned delivery tasks.
#[derive(Debug)]
pub(crate) struct TelemetryState {
    /// Map of extension ID to subscription info.
    subscriptions: Mutex<HashMap<String, Subscription>>,

    /// Buffered events per extension, keyed by extension ID.
    ///
    /// An entry is created when an extension subscribes. Events are appended by
    /// `broadcast_event` and removed by the delivery task or by `flush_all`.
    event_buffers: Mutex<HashMap<String, Vec<TelemetryEvent>>>,

    /// Background tasks for event delivery, keyed by extension ID.
    delivery_handles: Mutex<HashMap<String, JoinHandle<()>>>,

    /// HTTP client used by the interval-based background delivery tasks.
    http_client: reqwest::Client,

    /// HTTP/1.1 only client used by `flush_all`.
    /// The lambda_extension crate uses a simple hyper http1 server that doesn't support HTTP/2.
    http1_client: reqwest::Client,

    /// Test mode flag, set to `true` by `enable_capture`.
    ///
    /// NOTE(review): nothing visible in this file reads this flag; events are
    /// captured and delivered regardless of its value. Confirm whether it is
    /// still needed.
    capture_mode: Mutex<bool>,

    /// Events captured for test introspection.
    ///
    /// This buffer is bounded to `MAX_CAPTURED_EVENTS`. When full, the oldest
    /// events are dropped to make room for new ones.
    captured_events: Mutex<Vec<TelemetryEvent>>,

    /// Lock to synchronise flush operations with shutdown.
    ///
    /// Flush operations hold a read lock while sending events to extensions.
    /// `wait_for_flush_complete` requests the write lock, which is only granted
    /// once all in-flight flush operations have finished, so SHUTDOWN can be
    /// sent after extensions have received their telemetry.
    flush_lock: RwLock<()>,
}
83
84impl TelemetryState {
85    /// Creates a new telemetry state.
86    pub fn new() -> Self {
87        // Create HTTP/1.1 only client for telemetry delivery
88        // The lambda_extension crate uses a simple hyper http1 server
89        let http1_client = reqwest::Client::builder()
90            .http1_only()
91            .build()
92            .unwrap_or_else(|_| reqwest::Client::new());
93
94        Self {
95            subscriptions: Mutex::new(HashMap::new()),
96            event_buffers: Mutex::new(HashMap::new()),
97            delivery_handles: Mutex::new(HashMap::new()),
98            http_client: reqwest::Client::new(),
99            http1_client,
100            capture_mode: Mutex::new(false),
101            captured_events: Mutex::new(Vec::new()),
102            flush_lock: RwLock::new(()),
103        }
104    }
105
106    /// Subscribes an extension to telemetry events.
107    ///
108    /// # Arguments
109    ///
110    /// * `extension_id` - The extension's identifier
111    /// * `extension_name` - The extension's name
112    /// * `subscription` - Subscription configuration
113    pub async fn subscribe(
114        self: &Arc<Self>,
115        extension_id: String,
116        extension_name: String,
117        subscription: TelemetrySubscription,
118    ) {
119        let buffering = subscription.buffering.unwrap_or_default();
120
121        // The lambda_extension crate uses sandbox.localdomain as the destination host,
122        // which doesn't resolve in local test environments. Replace with 127.0.0.1.
123        let destination_uri = subscription
124            .destination
125            .uri
126            .replace("sandbox.localdomain", "127.0.0.1");
127
128        let subscribed_types: Vec<String> = subscription
129            .types
130            .iter()
131            .map(|t| format!("{:?}", t).to_lowercase())
132            .collect();
133
134        let sub = Subscription {
135            extension_id: extension_id.clone(),
136            extension_name: extension_name.clone(),
137            event_types: subscription.types,
138            destination_uri,
139            buffering: buffering.clone(),
140        };
141
142        self.subscriptions
143            .lock()
144            .await
145            .insert(extension_id.clone(), sub);
146
147        self.event_buffers
148            .lock()
149            .await
150            .insert(extension_id.clone(), Vec::new());
151
152        self.start_delivery_task(extension_id, buffering).await;
153
154        let subscription_event = TelemetryEvent {
155            time: Utc::now(),
156            event_type: "platform.telemetrySubscription".to_string(),
157            record: serde_json::to_value(PlatformTelemetrySubscription {
158                name: extension_name,
159                state: "Subscribed".to_string(),
160                types: subscribed_types,
161            })
162            .unwrap_or_default(),
163        };
164
165        self.broadcast_event(subscription_event, TelemetryEventType::Platform)
166            .await;
167    }
168
169    /// Starts a background task to deliver buffered events.
170    async fn start_delivery_task(
171        self: &Arc<Self>,
172        extension_id: String,
173        buffering: BufferingConfig,
174    ) {
175        let state = Arc::clone(self);
176        let timeout_ms = buffering.timeout_ms.unwrap_or(1000);
177        let max_items = buffering.max_items.unwrap_or(10000);
178        let ext_id_for_insert = extension_id.clone();
179
180        let handle = tokio::spawn(async move {
181            let mut interval = tokio::time::interval(Duration::from_millis(timeout_ms as u64));
182
183            loop {
184                interval.tick().await;
185
186                let events = {
187                    let mut buffers = state.event_buffers.lock().await;
188                    if let Some(buffer) = buffers.get_mut(&extension_id) {
189                        if buffer.is_empty() {
190                            continue;
191                        }
192
193                        let count = buffer.len().min(max_items as usize);
194                        buffer.drain(..count).collect::<Vec<_>>()
195                    } else {
196                        break;
197                    }
198                };
199
200                if let Some(sub) = state.subscriptions.lock().await.get(&extension_id) {
201                    let uri = sub.destination_uri.clone();
202                    let client = state.http_client.clone();
203
204                    tracing::debug!(
205                        count = events.len(),
206                        uri = %uri,
207                        "Sending telemetry events"
208                    );
209                    match client.post(&uri).json(&events).send().await {
210                        Ok(resp) => {
211                            tracing::debug!(status = %resp.status(), "Telemetry delivery response");
212                        }
213                        Err(e) => {
214                            tracing::warn!(error = %e, "Telemetry delivery error");
215                        }
216                    }
217                }
218            }
219        });
220
221        self.delivery_handles
222            .lock()
223            .await
224            .insert(ext_id_for_insert, handle);
225    }
226
227    /// Broadcasts a telemetry event to all subscribed extensions.
228    ///
229    /// Events are always captured internally for test introspection, even
230    /// if capture mode is not explicitly enabled. All buffers are bounded:
231    /// when full, the oldest events are dropped to make room for new ones.
232    ///
233    /// # Arguments
234    ///
235    /// * `event` - The telemetry event to broadcast
236    /// * `event_type` - The type of the event (platform, function, extension)
237    pub async fn broadcast_event(&self, event: TelemetryEvent, event_type: TelemetryEventType) {
238        // Always capture events for test introspection (bounded buffer)
239        {
240            let mut captured = self.captured_events.lock().await;
241            if captured.len() >= MAX_CAPTURED_EVENTS {
242                // Drop oldest event to make room
243                captured.remove(0);
244            }
245            captured.push(event.clone());
246        }
247
248        let subscriptions = self.subscriptions.lock().await;
249        let mut buffers = self.event_buffers.lock().await;
250
251        tracing::trace!(
252            event_type = ?event_type,
253            subscriptions = subscriptions.len(),
254            buffers = buffers.len(),
255            "Broadcasting telemetry event"
256        );
257
258        for (ext_id, sub) in subscriptions.iter() {
259            tracing::trace!(
260                extension_id = %ext_id,
261                event_types = ?sub.event_types,
262                matches = sub.event_types.contains(&event_type),
263                "Checking subscription"
264            );
265            if sub.event_types.contains(&event_type)
266                && let Some(buffer) = buffers.get_mut(ext_id)
267            {
268                tracing::trace!(extension_id = %ext_id, "Adding event to buffer");
269                let max_items = sub.buffering.max_items.unwrap_or(10000) as usize;
270                if buffer.len() >= max_items {
271                    // Drop oldest events to stay within limit
272                    let excess = buffer.len() - max_items + 1;
273                    buffer.drain(..excess);
274                    tracing::warn!(
275                        extension_id = %ext_id,
276                        dropped_events = excess,
277                        "Telemetry buffer overflow, dropped oldest events"
278                    );
279                }
280                buffer.push(event.clone());
281            }
282        }
283    }
284
285    /// Gets all active subscriptions.
286    #[allow(dead_code)]
287    pub async fn get_subscriptions(&self) -> Vec<Subscription> {
288        self.subscriptions.lock().await.values().cloned().collect()
289    }
290
291    /// Checks if an extension has an active telemetry subscription.
292    #[allow(dead_code)]
293    pub async fn is_subscribed(&self, extension_id: &str) -> bool {
294        self.subscriptions.lock().await.contains_key(extension_id)
295    }
296
297    /// Enables test mode telemetry capture.
298    ///
299    /// When enabled, all telemetry events are captured in memory
300    /// instead of being sent via HTTP. This is useful for test assertions.
301    pub async fn enable_capture(&self) {
302        *self.capture_mode.lock().await = true;
303    }
304
305    /// Gets all captured telemetry events.
306    ///
307    /// Returns an empty vector if capture mode is not enabled.
308    pub async fn get_captured_events(&self) -> Vec<TelemetryEvent> {
309        self.captured_events.lock().await.clone()
310    }
311
312    /// Gets captured telemetry events filtered by event type.
313    ///
314    /// # Arguments
315    ///
316    /// * `event_type` - The event type to filter by (e.g., "platform.start")
317    pub async fn get_captured_events_by_type(&self, event_type: &str) -> Vec<TelemetryEvent> {
318        self.captured_events
319            .lock()
320            .await
321            .iter()
322            .filter(|e| e.event_type == event_type)
323            .cloned()
324            .collect()
325    }
326
327    /// Clears all captured telemetry events.
328    pub async fn clear_captured_events(&self) {
329        self.captured_events.lock().await.clear();
330    }
331
332    /// Flushes all buffered telemetry events to subscribers immediately.
333    ///
334    /// This bypasses the normal interval-based delivery and sends all pending
335    /// events right away. Useful before shutdown to ensure extensions receive
336    /// all telemetry.
337    ///
338    /// This method holds the flush lock while sending, which allows
339    /// `wait_for_flush_complete` to synchronise with ongoing flush operations.
340    pub async fn flush_all(&self) {
341        tracing::debug!("Starting flush_all");
342        // Acquire read lock - allows concurrent flushes but blocks shutdown waiting
343        let _guard = self.flush_lock.read().await;
344
345        let subscriptions = self.subscriptions.lock().await;
346        let mut buffers = self.event_buffers.lock().await;
347
348        tracing::debug!(
349            subscriptions = subscriptions.len(),
350            buffers = buffers.len(),
351            "Flushing telemetry buffers"
352        );
353
354        for (ext_id, sub) in subscriptions.iter() {
355            if let Some(buffer) = buffers.get_mut(ext_id) {
356                if buffer.is_empty() {
357                    tracing::trace!(extension_id = %ext_id, "Buffer empty, skipping");
358                    continue;
359                }
360
361                let events = std::mem::take(buffer);
362                let uri = sub.destination_uri.clone();
363
364                tracing::debug!(
365                    extension_id = %ext_id,
366                    count = events.len(),
367                    uri = %uri,
368                    "Flushing events to extension"
369                );
370
371                // Send with retries to handle cases where extension's HTTP server
372                // isn't fully ready yet (especially during rapid test execution)
373                let mut attempts = 0;
374                let max_attempts = 5;
375
376                loop {
377                    attempts += 1;
378                    // Use http1_client as lambda_extension uses a simple hyper http1 server
379                    // Use .json() instead of .body() to ensure proper serialization
380                    match self.http1_client.post(&uri).json(&events).send().await {
381                        Ok(resp) if resp.status().is_success() => {
382                            tracing::debug!(
383                                status = %resp.status(),
384                                attempts,
385                                "Flush successful"
386                            );
387                            break;
388                        }
389                        Ok(resp) => {
390                            let status = resp.status();
391                            let body = resp.text().await.unwrap_or_default();
392                            tracing::debug!(
393                                status = %status,
394                                body = %body,
395                                attempts,
396                                "Flush attempt failed"
397                            );
398                            if attempts >= max_attempts {
399                                tracing::warn!(
400                                    extension_id = %ext_id,
401                                    status = %status,
402                                    "Failed to flush telemetry events after {} attempts",
403                                    max_attempts
404                                );
405                                break;
406                            }
407                            tokio::time::sleep(Duration::from_millis(200)).await;
408                        }
409                        Err(e) => {
410                            tracing::debug!(
411                                error = %e,
412                                attempts,
413                                "Flush attempt error"
414                            );
415                            if attempts >= max_attempts {
416                                tracing::warn!(
417                                    extension_id = %ext_id,
418                                    error = %e,
419                                    "Failed to flush telemetry events after {} attempts",
420                                    max_attempts
421                                );
422                                break;
423                            }
424                            tokio::time::sleep(Duration::from_millis(200)).await;
425                        }
426                    }
427                }
428            }
429        }
430        tracing::debug!("flush_all complete");
431    }
432
433    /// Waits for any in-progress flush operations to complete.
434    ///
435    /// This acquires a write lock on the flush lock, which blocks until all
436    /// concurrent read locks (held by `flush_all`) are released. Use this
437    /// before sending SHUTDOWN to ensure extensions have received all telemetry.
438    ///
439    /// The timeout parameter specifies how long to wait for flush operations
440    /// to complete. If the timeout expires, this method returns without
441    /// guaranteeing all flushes are complete.
442    pub async fn wait_for_flush_complete(&self, timeout: Duration) {
443        let result = tokio::time::timeout(timeout, self.flush_lock.write()).await;
444        if result.is_err() {
445            tracing::warn!(
446                timeout_ms = timeout.as_millis(),
447                "Timed out waiting for flush operations to complete"
448            );
449        }
450        // Lock is immediately dropped, we just needed to wait for it
451    }
452
453    /// Shuts down all background telemetry delivery tasks.
454    ///
455    /// This aborts all spawned delivery tasks to ensure clean shutdown.
456    pub async fn shutdown(&self) {
457        let mut handles = self.delivery_handles.lock().await;
458        for (_, handle) in handles.drain() {
459            handle.abort();
460        }
461    }
462}
463
464impl Default for TelemetryState {
465    fn default() -> Self {
466        Self::new()
467    }
468}