cryprot-net 0.2.3

Networking library for cryptographic protocols built on QUIC.
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
//! [`tracing_subscriber::Layer`] for structured communication metrics.
//!
//! The [`CommLayer`] is a [`tracing_subscriber::Layer`] which records numbers
//! of bytes read and written. Metrics are collected by
//! [`instrumenting`](`macro@tracing::instrument`) spans with the
//! `cryprot_metrics` target and a phase. From within these spans, events with
//! the same target can be emitted to track the number of bytes read/written.
//!
//! ```
//! use tracing::{event, instrument, Level};
//!
//! #[instrument(target = "cryprot_metrics", fields(phase = "Online"))]
//! async fn online() {
//!     event!(target: "cryprot_metrics", Level::TRACE, bytes_written = 5);
//!     interleaved_setup().await
//! }
//!
//! #[instrument(target = "cryprot_metrics", fields(phase = "Setup"))]
//! async fn interleaved_setup() {
//!     // Will be recorded in the sub phase "Setup" of the online phase
//!     event!(target: "cryprot_metrics", Level::TRACE, bytes_written = 10);
//! }
//! ```
use std::{
    collections::{BTreeMap, btree_map::Entry},
    fmt::Debug,
    mem,
    ops::AddAssign,
    sync::{Arc, Mutex},
};

use serde::{Deserialize, Serialize};
use tracing::{
    Level,
    field::{Field, Visit},
    span::{Attributes, Id},
    warn,
};
use tracing_subscriber::{
    filter::{Filtered, Targets},
    layer::{Context, Layer},
};

#[derive(Debug, Default, Clone, Serialize, Deserialize)]
/// Communication metrics for a phase and its sub phases.
pub struct CommData {
    /// Name of the phase, taken from the `phase` field of the span.
    pub phase: String,
    /// Bytes read, recorded via `bytes_read` events.
    pub read: Counter,
    /// Bytes written, recorded via `bytes_written` events.
    pub write: Counter,
    /// Metrics of the sub phases of this phase, keyed by their phase name.
    pub sub_comm_data: SubCommData,
}

#[derive(Debug, Default, Clone, Copy, Serialize, Deserialize)]
/// Byte counter for one direction (read or write) of a phase.
pub struct Counter {
    /// Number of bytes written/read directly in this phase.
    pub bytes: u64,
    /// Total number of bytes written/read in this phase and all sub phases.
    pub bytes_with_sub_comm: u64,
}

#[derive(Debug, Default, Clone, Serialize, Deserialize)]
/// Communication data of sub phases, keyed by phase name.
///
/// The same type is used for the root level, where it holds the data of all
/// top-level phases.
pub struct SubCommData(BTreeMap<String, CommData>);

/// Convenience type alias for a filtered `CommLayerData` which only handles
/// spans and events with `target = "cryprot_metrics"`.
///
/// Use [`new_comm_layer`] to create one.
pub type CommLayer<S> = Filtered<CommLayerData, Targets, S>;

#[derive(Clone, Debug, Default)]
/// The `CommLayerData` has shared ownership of the root [`SubCommData`].
///
/// Clones share the same underlying data, so a clone can be kept around to
/// read the metrics while another one is installed as a layer.
pub struct CommLayerData {
    // TODO: use atomics in SubCommData to not need the lock, maybe?
    /// Root-level metrics, guarded by a mutex and shared between all clones.
    comm_data: Arc<Mutex<SubCommData>>,
}

/// Create a [`CommLayer`] together with a [`CommLayerData`] handle to it.
///
/// The returned layer only handles spans and events whose target is
/// `cryprot_metrics`. The second tuple element shares its data with the layer
/// and is used to read or reset the collected metrics.
pub fn new_comm_layer<S>() -> (CommLayer<S>, CommLayerData)
where
    S: tracing::Subscriber,
    S: for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
{
    let data_handle = CommLayerData::default();
    // The layer gets its own clone; both clones point to the same root data.
    let layer: CommLayer<S> = data_handle
        .clone()
        .with_filter(Targets::new().with_target("cryprot_metrics", Level::TRACE));
    (layer, data_handle)
}

impl CommLayerData {
    /// Snapshot of the root `SubCommData` as it is right now.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn comm_data(&self) -> SubCommData {
        let root = self.comm_data.lock().expect("lock poisoned");
        root.clone()
    }

    /// Replaces the root `SubCommData` with an empty one and returns the
    /// previous contents.
    ///
    /// Do not use this method while an instrumented `target = cryprot_metrics`
    /// span is active, as this will result in inconsistent data.
    ///
    /// # Panics
    /// Panics if the internal lock is poisoned.
    pub fn reset(&self) -> SubCommData {
        mem::take(&mut *self.comm_data.lock().expect("lock poisoned"))
    }
}

impl<S> Layer<S> for CommLayerData
where
    S: tracing::Subscriber,
    S: for<'lookup> tracing_subscriber::registry::LookupSpan<'lookup>,
{
    /// Attaches an empty [`CommData`] to every new span that has a `phase`
    /// field. Spans without a phase get no `CommData`.
    fn on_new_span(&self, attrs: &Attributes<'_>, id: &Id, ctx: Context<'_, S>) {
        let span = ctx.span(id).expect("Id is valid");
        let mut visitor = PhaseVisitor(None);
        attrs.record(&mut visitor);
        if let Some(phase) = visitor.0 {
            let data = CommData::new(phase);
            span.extensions_mut().insert(data);
        }
    }

    /// Adds the `bytes_read` / `bytes_written` value of an event to the
    /// [`CommData`] of the span the event belongs to.
    ///
    /// Events outside of a span, events carrying both fields at once, and
    /// events in a span without a phase are not tracked; a warning is emitted
    /// instead.
    fn on_event(&self, event: &tracing::Event<'_>, ctx: Context<'_, S>) {
        let Some(span) = ctx.event_span(event) else {
            warn!(
                "Received cryprot_metrics event outside of cryprot_metrics span. \
                Communication is not tracked"
            );
            return;
        };
        // Check that we only have one field per event, otherwise the CommEventVisitor
        // will only record one of them
        let field_cnt = event
            .fields()
            .filter(|field| field.name() == "bytes_read" || field.name() == "bytes_written")
            .count();
        if field_cnt >= 2 {
            warn!("Use individual events to record bytes_read and bytes_written");
            return;
        }
        let mut vis = CommEventVisitor(None);
        event.record(&mut vis);
        let Some(comm_event) = vis.0 else {
            // the event carried neither bytes_read nor bytes_written
            return;
        };

        let mut extensions = span.extensions_mut();
        if let Some(comm_data) = extensions.get_mut::<CommData>() {
            match comm_event {
                CommEvent::Read(read) => comm_data.read += read,
                CommEvent::Write(written) => comm_data.write += written,
            }
            return;
        }
        // Release the write guard on the span's extensions before emitting the
        // warning. Otherwise another layer which handles the warning and reads
        // the extensions of this span would block on the lock we are holding.
        drop(extensions);
        warn!(
            "Received cryprot_metrics event inside cryprot_metrics span with no phase. \
            Communication is not tracked"
        );
    }

    /// Moves the [`CommData`] of a closing span into its parent span, or into
    /// the root data if the span has no parent.
    ///
    /// If the parent span has no `CommData` (it has no phase), the data of the
    /// closing span is dropped and a warning is emitted.
    fn on_close(&self, id: Id, ctx: Context<'_, S>) {
        let span = ctx.span(&id).expect("Id is valid");
        // Take the data out in its own statement so the guard on this span's
        // extensions is released before any other lock is acquired.
        let taken = span.extensions_mut().get_mut::<CommData>().map(mem::take);
        let Some(comm_data) = taken else {
            // nothing to do
            return;
        };

        // TODO can merging of comm data be done in a background thread? Benchmark
        // first!
        let Some(parent) = span.parent() else {
            // Top-level phase: accumulate into the shared root data.
            let mut root_comm_data = self.comm_data.lock().expect("lock poisoned");
            let phase_comm_data = root_comm_data
                .0
                .entry(comm_data.phase.clone())
                .or_insert_with(|| CommData::new(comm_data.phase.clone()));
            merge(comm_data, phase_comm_data);
            return;
        };

        let mut parent_extensions = parent.extensions_mut();
        if let Some(parent_comm_data) = parent_extensions.get_mut::<CommData>() {
            // The parent's totals include everything communicated in this sub phase.
            parent_comm_data.read.bytes_with_sub_comm += comm_data.read.bytes_with_sub_comm;
            parent_comm_data.write.bytes_with_sub_comm += comm_data.write.bytes_with_sub_comm;
            let entry = parent_comm_data
                .sub_comm_data
                .0
                .entry(comm_data.phase.clone())
                .or_insert_with(|| CommData::new(comm_data.phase.clone()));
            merge(comm_data, entry);
            return;
        }
        // Do not hold the parent's extensions lock while emitting the warning,
        // other layers might need to read them to handle it.
        drop(parent_extensions);
        warn!(
            "Closed cryprot_metrics span with a phase inside cryprot_metrics span with no phase. \
            Communication is not tracked"
        );
    }
}

fn merge(from: CommData, into: &mut CommData) {
    into.read += from.read;
    into.write += from.write;
    for (phase, from_sub_comm) in from.sub_comm_data.0.into_iter() {
        match into.sub_comm_data.0.entry(phase) {
            Entry::Vacant(entry) => {
                entry.insert(from_sub_comm);
            }
            Entry::Occupied(mut entry) => {
                merge(from_sub_comm, entry.get_mut());
            }
        }
    }
}

impl SubCommData {
    /// Get the [`CommData`] for a phase.
    pub fn get(&self, phase: &str) -> Option<&CommData> {
        self.0.get(phase)
    }

    /// Iterate over all [`CommData`].
    pub fn iter(&self) -> impl Iterator<Item = &CommData> {
        self.0.values()
    }
}

impl AddAssign for Counter {
    fn add_assign(&mut self, rhs: Self) {
        self.bytes += rhs.bytes;
        self.bytes_with_sub_comm += rhs.bytes_with_sub_comm;
    }
}

impl AddAssign<u64> for Counter {
    /// Counts `rhs` bytes as communicated directly in this phase, which
    /// increases the direct count as well as the total including sub phases.
    fn add_assign(&mut self, rhs: u64) {
        for total in [&mut self.bytes, &mut self.bytes_with_sub_comm] {
            *total += rhs;
        }
    }
}

impl CommData {
    fn new(phase: String) -> Self {
        Self {
            phase,
            ..Default::default()
        }
    }
}

/// Field visitor which extracts the value of the `phase` field, if present.
struct PhaseVisitor(Option<String>);

impl Visit for PhaseVisitor {
    /// Stores a `phase` field recorded as a string verbatim.
    fn record_str(&mut self, field: &Field, value: &str) {
        if field.name() != "phase" {
            return;
        }
        self.0 = Some(String::from(value));
    }

    /// Stores a `phase` field recorded via `Debug` using its debug output.
    fn record_debug(&mut self, field: &Field, value: &dyn Debug) {
        if field.name() != "phase" {
            return;
        }
        self.0 = Some(format!("{value:?}"));
    }
}

/// A single communication event extracted from a `cryprot_metrics` event.
enum CommEvent {
    /// Number of bytes from a `bytes_read` field.
    Read(u64),
    /// Number of bytes from a `bytes_written` field.
    Write(u64),
}

/// Field visitor which extracts a [`CommEvent`] from the `bytes_read` or
/// `bytes_written` field of an event. If several such fields are visited,
/// the last one wins.
struct CommEventVisitor(Option<CommEvent>);

impl CommEventVisitor {
    /// Records `value` if `field` is `bytes_written` or `bytes_read`; all
    /// other fields are ignored.
    ///
    /// # Panics
    /// Panics if the value of one of these two fields does not fit into a
    /// `u64`.
    fn record<T>(&mut self, field: &Field, value: T)
    where
        T: TryInto<u64>,
        T::Error: Debug,
    {
        // Select the variant first so that unrelated fields are never converted.
        let to_event: fn(u64) -> CommEvent = match field.name() {
            "bytes_written" => CommEvent::Write,
            "bytes_read" => CommEvent::Read,
            _ => return,
        };
        let bytes: u64 = value
            .try_into()
            .expect("recorded bytes must be convertible to u64");
        self.0 = Some(to_event(bytes));
    }
}

// All integer kinds are forwarded to the inherent `CommEventVisitor::record`,
// which does the conversion to `u64`.
impl Visit for CommEventVisitor {
    fn record_i64(&mut self, field: &Field, value: i64) {
        Self::record(self, field, value)
    }

    fn record_u64(&mut self, field: &Field, value: u64) {
        Self::record(self, field, value)
    }

    fn record_i128(&mut self, field: &Field, value: i128) {
        Self::record(self, field, value)
    }

    fn record_u128(&mut self, field: &Field, value: u128) {
        Self::record(self, field, value)
    }

    /// Values arriving here are not integers; they are reported with a
    /// warning and otherwise ignored.
    fn record_debug(&mut self, field: &Field, value: &dyn Debug) {
        let name = field.name();
        warn!(
            "cryprot_metrics event with field which is not an integer. {}: {:?}",
            name, value
        )
    }
}

#[cfg(test)]
mod tests {
    use std::time::Duration;

    use tokio::{self, join, time::sleep};
    use tracing::{Instrument, Level, event, instrument};
    use tracing_subscriber::{Registry, layer::SubscriberExt};

    use crate::metrics::new_comm_layer;

    // A phase with one nested sub phase: checks the direct counts, the totals
    // including the sub phase, and that `reset` clears the root data.
    #[tokio::test]
    async fn test_communication_metrics() {
        #[instrument(target = "cryprot_metrics", fields(phase = "TopLevel"))]
        async fn top_level_operation() {
            // Simulate some direct communication
            event!(target: "cryprot_metrics", Level::TRACE, bytes_read = 100);
            event!(target: "cryprot_metrics", Level::TRACE, bytes_written = 200);

            // Call sub-operation
            sub_operation().await;
        }

        #[instrument(target = "cryprot_metrics", fields(phase = "SubOperation"))]
        async fn sub_operation() {
            // Simulate some communication in the sub-operation
            event!(target: "cryprot_metrics", Level::TRACE, bytes_read = 50);
            event!(target: "cryprot_metrics", Level::TRACE, bytes_written = 100);
        }

        // Set up the metrics layer; the returned guard is kept alive for the
        // rest of the test
        let (comm_layer, comm_data) = new_comm_layer();
        let subscriber = Registry::default().with(comm_layer);
        let _guard = tracing::subscriber::set_default(subscriber);

        // Run instrumented functions
        top_level_operation().await;

        // Verify metrics
        let metrics = comm_data.comm_data();

        // Check top level metrics
        let top_phase = metrics
            .get("TopLevel")
            .expect("TopLevel phase should exist");
        assert_eq!(top_phase.phase, "TopLevel");
        assert_eq!(top_phase.read.bytes, 100);
        assert_eq!(top_phase.write.bytes, 200);
        assert_eq!(top_phase.read.bytes_with_sub_comm, 150); // 100 (direct) + 50 (from sub)
        assert_eq!(top_phase.write.bytes_with_sub_comm, 300); // 200 (direct) + 100 (from sub)

        // Check sub-phase metrics
        let sub_phase = top_phase
            .sub_comm_data
            .get("SubOperation")
            .expect("SubOperation phase should exist");
        assert_eq!(sub_phase.phase, "SubOperation");
        assert_eq!(sub_phase.read.bytes, 50);
        assert_eq!(sub_phase.write.bytes, 100);
        assert_eq!(sub_phase.read.bytes_with_sub_comm, 50);
        assert_eq!(sub_phase.write.bytes_with_sub_comm, 100);

        // Reset metrics and verify they're cleared: `reset` hands back the old
        // data, a fresh snapshot afterwards must be empty
        let reset_metrics = comm_data.reset();
        assert!(reset_metrics.get("TopLevel").is_some());
        let new_metrics = comm_data.comm_data();
        assert!(new_metrics.get("TopLevel").is_none());
    }

    // Two spans with the same phase names run concurrently: their metrics
    // must be accumulated under a single entry per phase name.
    #[tokio::test]
    async fn test_parallel_span_accumulation() {
        #[instrument(target = "cryprot_metrics", fields(phase = "ParentPhase"))]
        async fn parallel_operation(id: u32) {
            // If communication of a sub-phase happens in a spawned task, the future needs
            // to be instrumented with the current span to preserve hierarchy
            tokio::spawn(sub_operation(id).in_current_span())
                .await
                .unwrap();
        }

        #[instrument(target = "cryprot_metrics", fields(phase = "SubPhase"))]
        async fn sub_operation(id: u32) {
            // Each sub-operation does some communication
            event!(
                target: "cryprot_metrics",
                Level::TRACE,
                bytes_written = 100,
            );
            event!(
                target: "cryprot_metrics",
                Level::TRACE,
                bytes_read = 50
            );
            // Simulate some work to increase chance of overlap
            sleep(Duration::from_millis(10)).await;
        }

        // Set up the metrics layer
        let (comm_layer, comm_data) = new_comm_layer();
        let subscriber = Registry::default().with(comm_layer);
        let _guard = tracing::subscriber::set_default(subscriber);

        // Run parallel operations
        join!(parallel_operation(1), parallel_operation(2));

        // Verify metrics
        let metrics = comm_data.comm_data();
        let phase = metrics
            .get("ParentPhase")
            .expect("ParentPhase should exist");

        // The sub-phase metrics should accumulate from both parallel operations
        let sub_phase = phase
            .sub_comm_data
            .get("SubPhase")
            .expect("SubPhase should exist");

        // Each parallel operation writes 100 bytes in the sub-phase
        // So we expect 200 total bytes written in the sub-phase
        assert_eq!(
            sub_phase.write.bytes, 200,
            "Expected accumulated writes from both parallel operations"
        );

        // Each parallel operation reads 50 bytes in the sub-phase
        // So we expect 100 total bytes read in the sub-phase
        assert_eq!(
            sub_phase.read.bytes, 100,
            "Expected accumulated reads from both parallel operations"
        );

        // Parent phase should accumulate all sub-phase metrics; the parent
        // itself records no events, so its totals come from the sub-phase only
        assert_eq!(
            phase.write.bytes_with_sub_comm, 200,
            "Parent should include all sub-phase writes"
        );
        assert_eq!(
            phase.read.bytes_with_sub_comm, 100,
            "Parent should include all sub-phase reads"
        );
    }
}