nominal_streaming/
stream.rs

1use std::collections::BTreeMap;
2use std::collections::HashMap;
3use std::fmt::Debug;
4use std::sync::atomic::AtomicBool;
5use std::sync::atomic::AtomicU64;
6use std::sync::atomic::AtomicUsize;
7use std::sync::atomic::Ordering;
8use std::sync::Arc;
9use std::thread;
10use std::time::Duration;
11use std::time::Instant;
12use std::time::UNIX_EPOCH;
13
14use conjure_object::BearerToken;
15use nominal_api::tonic::io::nominal::scout::api::proto::points::PointsType;
16use nominal_api::tonic::io::nominal::scout::api::proto::Channel;
17use nominal_api::tonic::io::nominal::scout::api::proto::DoublePoint;
18use nominal_api::tonic::io::nominal::scout::api::proto::DoublePoints;
19use nominal_api::tonic::io::nominal::scout::api::proto::IntegerPoint;
20use nominal_api::tonic::io::nominal::scout::api::proto::IntegerPoints;
21use nominal_api::tonic::io::nominal::scout::api::proto::Points;
22use nominal_api::tonic::io::nominal::scout::api::proto::Series;
23use nominal_api::tonic::io::nominal::scout::api::proto::StringPoint;
24use nominal_api::tonic::io::nominal::scout::api::proto::StringPoints;
25use nominal_api::tonic::io::nominal::scout::api::proto::WriteRequestNominal;
26use parking_lot::Condvar;
27use parking_lot::Mutex;
28use parking_lot::MutexGuard;
29use tracing::debug;
30use tracing::error;
31use tracing::info;
32use tracing::warn;
33
34use crate::consumer::WriteRequestConsumer;
35
/// Identifies a channel by name together with its tag set.
///
/// Used internally as a hashable/orderable key when grouping points.
/// Field order matters: the derived `Ord`/`PartialOrd` compare `name`
/// first, then `tags`.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd)]
pub struct ChannelDescriptor {
    /// The name of the channel.
    pub name: String,
    /// The tags associated with the channel; `BTreeMap` keeps iteration
    /// deterministic so hashing and ordering are stable.
    pub tags: BTreeMap<String, String>,
}

impl ChannelDescriptor {
    /// Builds a descriptor from anything convertible into a name plus an
    /// iterator of key/value tag pairs.
    pub fn new(
        name: impl Into<String>,
        tags: impl IntoIterator<Item = (impl Into<String>, impl Into<String>)>,
    ) -> Self {
        let mut tag_map = BTreeMap::new();
        for (key, value) in tags {
            tag_map.insert(key.into(), value.into());
        }
        Self {
            name: name.into(),
            tags: tag_map,
        }
    }
}
61
/// Supplies bearer tokens for authenticating outgoing requests.
///
/// Implementations must be cheaply cloneable and shareable across the
/// worker threads spawned by the stream (`Clone + Send + Sync`).
pub trait AuthProvider: Clone + Send + Sync {
    /// Returns the current token, or `None` if no credential is available.
    fn token(&self) -> Option<BearerToken>;
}
65
/// Conversion of a point collection into the protobuf `PointsType` payload.
pub trait IntoPoints {
    /// Consumes `self` and produces the corresponding `PointsType` variant.
    fn into_points(self) -> PointsType;
}
69
/// Identity conversion: a `PointsType` is already in its final form.
impl IntoPoints for PointsType {
    fn into_points(self) -> PointsType {
        self
    }
}
75
/// Wraps a vector of double points in the `DoublePoints` variant.
impl IntoPoints for Vec<DoublePoint> {
    fn into_points(self) -> PointsType {
        PointsType::DoublePoints(DoublePoints { points: self })
    }
}
81
/// Wraps a vector of string points in the `StringPoints` variant.
impl IntoPoints for Vec<StringPoint> {
    fn into_points(self) -> PointsType {
        PointsType::StringPoints(StringPoints { points: self })
    }
}
87
/// Wraps a vector of integer points in the `IntegerPoints` variant.
impl IntoPoints for Vec<IntegerPoint> {
    fn into_points(self) -> PointsType {
        PointsType::IntegerPoints(IntegerPoints { points: self })
    }
}
93
/// Tuning knobs for a `NominalDatasourceStream`.
#[derive(Debug, Clone)]
pub struct NominalStreamOpts {
    /// Soft cap on points batched into one write request; a single enqueue
    /// larger than this may still exceed it (see `SeriesBuffer::has_capacity`).
    pub max_points_per_record: usize,
    /// Maximum time a batch processor sleeps (parks) before flushing.
    pub max_request_delay: Duration,
    /// Capacity of the bounded channel between batchers and dispatchers;
    /// batchers block on `send` when dispatchers fall this far behind.
    pub max_buffered_requests: usize,
    /// Number of dispatcher threads consuming write requests.
    pub request_dispatcher_tasks: usize,
}
101
102impl Default for NominalStreamOpts {
103    fn default() -> Self {
104        Self {
105            max_points_per_record: 250_000,
106            max_request_delay: Duration::from_millis(100),
107            max_buffered_requests: 4,
108            request_dispatcher_tasks: 8,
109        }
110    }
111}
112
/// Handle to a streaming pipeline that batches enqueued points and writes
/// them via a `WriteRequestConsumer`.
///
/// Dropping the stream signals shutdown and blocks until buffered points
/// are flushed (see the `Drop` impl below).
pub struct NominalDatasourceStream {
    /// Cleared on drop to tell the worker threads to exit.
    running: Arc<AtomicBool>,
    /// Points enqueued but not yet delivered by a dispatcher thread.
    unflushed_points: Arc<AtomicUsize>,
    /// First-choice accumulation buffer.
    primary_buffer: Arc<SeriesBuffer>,
    /// Overflow buffer used when the primary buffer is full.
    secondary_buffer: Arc<SeriesBuffer>,
    /// Batch-processor thread draining the primary buffer; unparked by
    /// `enqueue` to force an early flush.
    primary_handle: thread::JoinHandle<()>,
    /// Batch-processor thread draining the secondary buffer.
    secondary_handle: thread::JoinHandle<()>,
}
121
impl NominalDatasourceStream {
    /// Builds a stream that batches enqueued points and hands the resulting
    /// write requests to `consumer`.
    ///
    /// Spawns two batch-processor threads (one per buffer) whose handles are
    /// retained for unparking, and `opts.request_dispatcher_tasks` dispatcher
    /// threads whose join handles are dropped (detached).
    ///
    /// # Panics
    /// Panics if any worker thread fails to spawn.
    pub fn new_with_consumer<C: WriteRequestConsumer + 'static>(
        consumer: C,
        opts: NominalStreamOpts,
    ) -> Self {
        let primary_buffer = Arc::new(SeriesBuffer::new(opts.max_points_per_record));
        let secondary_buffer = Arc::new(SeriesBuffer::new(opts.max_points_per_record));

        // Bounded channel provides backpressure: batch processors block in
        // `send` once `max_buffered_requests` requests are queued.
        let (request_tx, request_rx) =
            crossbeam_channel::bounded::<(WriteRequestNominal, usize)>(opts.max_buffered_requests);

        let running = Arc::new(AtomicBool::new(true));
        let unflushed_points = Arc::new(AtomicUsize::new(0));

        let primary_handle = thread::Builder::new()
            .name("nmstream_primary".to_string())
            .spawn({
                let points_buffer = Arc::clone(&primary_buffer);
                let running = running.clone();
                // Clone of the sender; the original moves into the secondary
                // thread below.
                let tx = request_tx.clone();
                move || {
                    batch_processor(running, points_buffer, tx, opts.max_request_delay);
                }
            })
            .unwrap();

        let secondary_handle = thread::Builder::new()
            .name("nmstream_secondary".to_string())
            .spawn({
                let secondary_buffer = Arc::clone(&secondary_buffer);
                let running = running.clone();
                // Moves `request_tx` itself: once both processors exit, every
                // sender is dropped and the dispatchers observe channel close.
                move || {
                    batch_processor(
                        running,
                        secondary_buffer,
                        request_tx,
                        opts.max_request_delay,
                    );
                }
            })
            .unwrap();

        let consumer = Arc::new(consumer);

        // Dispatcher threads are detached; they exit when the request channel
        // closes or when shutdown completes (see `request_dispatcher`).
        for i in 0..opts.request_dispatcher_tasks {
            thread::Builder::new()
                .name(format!("nmstream_dispatch_{i}"))
                .spawn({
                    let running = Arc::clone(&running);
                    let unflushed_points = Arc::clone(&unflushed_points);
                    let rx = request_rx.clone();
                    let consumer = consumer.clone();
                    move || {
                        debug!("starting request dispatcher");
                        request_dispatcher(running, unflushed_points, rx, consumer);
                    }
                })
                .unwrap();
        }

        NominalDatasourceStream {
            running,
            unflushed_points,
            primary_buffer,
            secondary_buffer,
            primary_handle,
            secondary_handle,
        }
    }

    /// Queues points for `channel_descriptor`, preferring the primary buffer
    /// and spilling to the secondary when the primary is full.
    ///
    /// If both buffers are full this call BLOCKS (via `add_on_notify`) until
    /// the chosen buffer is flushed. The unflushed-point counter is bumped up
    /// front and decremented by a dispatcher after delivery.
    pub fn enqueue(&self, channel_descriptor: &ChannelDescriptor, new_points: impl IntoPoints) {
        let new_points = new_points.into_points();
        let new_count = points_len(&new_points);

        self.unflushed_points
            .fetch_add(new_count, Ordering::Release);

        if self.primary_buffer.has_capacity(new_count) {
            debug!("adding {} points to primary buffer", new_count);
            self.primary_buffer
                .add_points(channel_descriptor, new_points);
        } else if self.secondary_buffer.has_capacity(new_count) {
            // primary buffer is definitely full
            // Unparking wakes the primary batch thread so it flushes early
            // rather than waiting out its full delay.
            self.primary_handle.thread().unpark();
            debug!("adding {} points to secondary buffer", new_count);
            self.secondary_buffer
                .add_points(channel_descriptor, new_points);
        } else {
            warn!("both buffers are full, picking least recently flushed buffer to add to");
            // both buffers are full - wait on the buffer that flushed least recently (i.e more
            // likely that it's nearly done)
            // The `<` goes through SeriesBuffer's PartialOrd, which compares
            // last-flush timestamps loaded from atomics at this instant.
            let buf = if self.primary_buffer < self.secondary_buffer {
                debug!("waiting for primary buffer to flush...");
                self.primary_handle.thread().unpark();
                &self.primary_buffer
            } else {
                debug!("waiting for secondary buffer to flush...");
                self.secondary_handle.thread().unpark();
                &self.secondary_buffer
            };
            buf.add_on_notify(channel_descriptor, new_points);
            debug!("added points after wait to chosen buffer")
        }
    }
}
227
/// Accumulation buffer of points grouped by channel, shared between the
/// enqueueing caller and one batch-processor thread.
struct SeriesBuffer {
    /// Pending points keyed by channel descriptor.
    points: Mutex<HashMap<ChannelDescriptor, PointsType>>,
    /// Total points currently buffered across all channels.
    count: AtomicUsize,
    /// Nanoseconds since the Unix epoch of the most recent flush (`take`).
    flush_time: AtomicU64,
    /// Wakes threads blocked in `add_on_notify` after a flush empties the buffer.
    condvar: Condvar,
    /// Soft cap on buffered points (see `has_capacity` for the exception).
    max_capacity: usize,
}
235
/// Equality compares last-flush timestamps, not buffer contents.
///
/// NOTE(review): this is a point-in-time comparison of atomics and can
/// change between calls; it exists only so `enqueue` can pick the least
/// recently flushed buffer via `<`, not as a stable equivalence relation.
impl PartialEq for SeriesBuffer {
    fn eq(&self, other: &Self) -> bool {
        self.flush_time.load(Ordering::Acquire) == other.flush_time.load(Ordering::Acquire)
    }
}
241
242impl PartialOrd for SeriesBuffer {
243    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
244        let flush_time = self.flush_time.load(Ordering::Acquire);
245        let other_flush_time = other.flush_time.load(Ordering::Acquire);
246        flush_time.partial_cmp(&other_flush_time)
247    }
248}
249
250impl SeriesBuffer {
251    fn new(capacity: usize) -> Self {
252        Self {
253            points: Mutex::new(HashMap::new()),
254            count: AtomicUsize::new(0),
255            flush_time: AtomicU64::new(0),
256            condvar: Condvar::new(),
257            max_capacity: capacity,
258        }
259    }
260
261    /// Checks if the buffer has enough capacity to add new points.
262    /// Note that the buffer can be larger than MAX_POINTS_PER_RECORD if a single batch of points
263    /// larger than MAX_POINTS_PER_RECORD is inserted while the buffer is empty. This avoids needing
264    /// to handle splitting batches of points across multiple requests.
265    fn has_capacity(&self, new_points_count: usize) -> bool {
266        let count = self.count.load(Ordering::Acquire);
267        count == 0 || count + new_points_count <= self.max_capacity
268    }
269
270    fn add_points(&self, channel_descriptor: &ChannelDescriptor, new_points: PointsType) {
271        self.inner_add_points(channel_descriptor, new_points, self.points.lock());
272    }
273
274    fn inner_add_points(
275        &self,
276        channel_descriptor: &ChannelDescriptor,
277        new_points: PointsType,
278        mut points_guard: MutexGuard<HashMap<ChannelDescriptor, PointsType>>,
279    ) {
280        self.count
281            .fetch_add(points_len(&new_points), Ordering::Release);
282        match (points_guard.get_mut(channel_descriptor), new_points) {
283            (None, new_points) => {
284                points_guard.insert(channel_descriptor.clone(), new_points);
285            }
286            (Some(PointsType::DoublePoints(points)), PointsType::DoublePoints(new_points)) => {
287                points
288                    .points
289                    .extend_from_slice(new_points.points.as_slice());
290            }
291            (Some(PointsType::StringPoints(points)), PointsType::StringPoints(new_points)) => {
292                points
293                    .points
294                    .extend_from_slice(new_points.points.as_slice());
295            }
296            (Some(PointsType::IntegerPoints(points)), PointsType::IntegerPoints(new_points)) => {
297                points
298                    .points
299                    .extend_from_slice(new_points.points.as_slice());
300            }
301            (Some(PointsType::DoublePoints(_)), PointsType::StringPoints(_)) => {
302                // todo: return an error instead of panicking
303                panic!(
304                    "attempting to add points of the wrong type to an existing channel. expected: double. provided: string"
305                )
306            }
307            (Some(PointsType::DoublePoints(_)), PointsType::IntegerPoints(_)) => {
308                // todo: return an error instead of panicking
309                panic!(
310                    "attempting to add points of the wrong type to an existing channel. expected: double. provided: string"
311                )
312            }
313            (Some(PointsType::StringPoints(_)), PointsType::DoublePoints(_)) => {
314                // todo: return an error instead of panicking
315                panic!(
316                    "attempting to add points of the wrong type to an existing channel. expected: string. provided: double"
317                )
318            }
319            (Some(PointsType::StringPoints(_)), PointsType::IntegerPoints(_)) => {
320                // todo: return an error instead of panicking
321                panic!(
322                    "attempting to add points of the wrong type to an existing channel. expected: string. provided: double"
323                )
324            }
325            (Some(PointsType::IntegerPoints(_)), PointsType::DoublePoints(_)) => {
326                // todo: return an error instead of panicking
327                panic!(
328                    "attempting to add points of the wrong type to an existing channel. expected: string. provided: double"
329                )
330            }
331            (Some(PointsType::IntegerPoints(_)), PointsType::StringPoints(_)) => {
332                // todo: return an error instead of panicking
333                panic!(
334                    "attempting to add points of the wrong type to an existing channel. expected: string. provided: double"
335                )
336            }
337        }
338    }
339
340    fn take(&self) -> (usize, Vec<Series>) {
341        let mut points = self.points.lock();
342        self.flush_time.store(
343            UNIX_EPOCH.elapsed().unwrap().as_nanos() as u64,
344            Ordering::Release,
345        );
346        let result = points
347            .drain()
348            .map(|(ChannelDescriptor { name, tags }, points)| {
349                let channel = Channel { name };
350                let points_obj = Points {
351                    points_type: Some(points),
352                };
353                Series {
354                    channel: Some(channel),
355                    tags: tags.into_iter().collect(),
356                    points: Some(points_obj),
357                }
358            })
359            .collect();
360        let result_count = self
361            .count
362            .fetch_update(Ordering::Release, Ordering::Acquire, |_| Some(0))
363            .unwrap();
364        (result_count, result)
365    }
366
367    fn is_empty(&self) -> bool {
368        self.count() == 0
369    }
370
371    fn count(&self) -> usize {
372        self.count.load(Ordering::Acquire)
373    }
374
375    fn add_on_notify(&self, channel_descriptor: &ChannelDescriptor, new_points: PointsType) {
376        let mut points_lock = self.points.lock();
377        // concurrency bug without this - the buffer could have been emptied since we
378        // checked the count, so this will wait forever & block any new points from entering
379        if !points_lock.is_empty() {
380            self.condvar.wait(&mut points_lock);
381        } else {
382            debug!("buffer emptied since last check, skipping condvar wait");
383        }
384        self.inner_add_points(channel_descriptor, new_points, points_lock);
385    }
386
387    fn notify(&self) -> bool {
388        self.condvar.notify_one()
389    }
390}
391
/// Worker loop for one `SeriesBuffer`: repeatedly drains the buffer into a
/// `WriteRequestNominal` and sends it (with its point count) to the
/// dispatcher channel.
///
/// Parks for up to `max_request_delay` between iterations; `enqueue`
/// unparks the thread early when the buffer fills. Exits once `running` is
/// false AND the buffer is empty.
fn batch_processor(
    running: Arc<AtomicBool>,
    points_buffer: Arc<SeriesBuffer>,
    request_chan: crossbeam_channel::Sender<(WriteRequestNominal, usize)>,
    max_request_delay: Duration,
) {
    loop {
        debug!("starting processor loop");
        if points_buffer.is_empty() {
            if !running.load(Ordering::Acquire) {
                debug!("batch processor thread exiting due to running flag");
                // Dropping this sender lets dispatchers observe channel
                // closure once the other batch processor has exited too.
                drop(request_chan);
                break;
            } else {
                debug!("empty points buffer, waiting");
                thread::park_timeout(max_request_delay);
            }
            continue;
        }
        let (point_count, series) = points_buffer.take();

        // Wake one thread blocked in `add_on_notify`, now that the buffer
        // has been drained and has room again.
        if points_buffer.notify() {
            debug!("notified one waiting thread after clearing points buffer");
        }

        let write_request = WriteRequestNominal { series };

        if request_chan.is_full() {
            warn!("request channel is full");
        }
        // Bounded channel: this blocks when dispatchers are saturated.
        let rep = request_chan.send((write_request, point_count));
        debug!("queued request for processing");
        if rep.is_err() {
            // NOTE(review): on send failure the drained points are dropped and
            // `unflushed_points` is never decremented for them, which can make
            // the stream's Drop wait on points that will never flush - confirm.
            error!("failed to send request to dispatcher");
        } else {
            debug!("finished submitting request");
        }

        thread::park_timeout(max_request_delay);
    }
    debug!("batch processor thread exiting");
}
434
435impl Drop for NominalDatasourceStream {
436    fn drop(&mut self) {
437        debug!("starting drop for NominalDatasourceStream");
438        self.running.store(false, Ordering::Release);
439        while self.unflushed_points.load(Ordering::Acquire) > 0 {
440            debug!(
441                "waiting for all points to be flushed before dropping stream, {} points remaining",
442                self.unflushed_points.load(Ordering::Acquire)
443            );
444            // todo: reduce this + give up after some maximum timeout is reached
445            thread::sleep(Duration::from_millis(50));
446        }
447    }
448}
449
/// Dispatcher worker: receives `(request, point_count)` pairs from the
/// channel and forwards each request to the consumer.
///
/// Decrements `unflushed_points` by the batch size after each attempt,
/// whether or not the consumer succeeded, so shutdown does not hang on
/// failed batches. Exits when the channel closes (all batch processors have
/// dropped their senders) or when shutdown is requested and every point has
/// been flushed.
fn request_dispatcher<C: WriteRequestConsumer + 'static>(
    running: Arc<AtomicBool>,
    unflushed_points: Arc<AtomicUsize>,
    request_rx: crossbeam_channel::Receiver<(WriteRequestNominal, usize)>,
    consumer: Arc<C>,
) {
    // Accumulated milliseconds spent in successful consumer calls, reported
    // in the exit log line.
    let mut total_request_time = 0;
    loop {
        match request_rx.recv() {
            Ok((request, point_count)) => {
                debug!("received writerequest from channel");
                let req_start = Instant::now();
                match consumer.consume(&request) {
                    Ok(_) => {
                        let time = req_start.elapsed().as_millis();
                        debug!("request of {} points sent in {} ms", point_count, time);
                        total_request_time += time as u64;
                    }
                    Err(e) => {
                        // NOTE(review): failed batches are logged and dropped,
                        // not retried - their points are still counted as
                        // flushed by the decrement below.
                        error!("Failed to send request: {e:?}");
                    }
                }
                unflushed_points.fetch_sub(point_count, Ordering::Release);

                if unflushed_points.load(Ordering::Acquire) == 0 && !running.load(Ordering::Acquire)
                {
                    info!("all points flushed, closing dispatcher thread");
                    // notify the processor thread that all points have been flushed
                    drop(request_rx);
                    break;
                }
            }
            Err(e) => {
                // All senders dropped: the batch processors have exited.
                debug!("request channel closed, exiting dispatcher thread. info: '{e}'");
                break;
            }
        }
    }
    debug!(
        "request dispatcher thread exiting. total request time: {}",
        total_request_time
    );
}
493
494fn points_len(points_type: &PointsType) -> usize {
495    match points_type {
496        PointsType::DoublePoints(points) => points.points.len(),
497        PointsType::StringPoints(points) => points.points.len(),
498        PointsType::IntegerPoints(points) => points.points.len(),
499    }
500}