// nominal_streaming/stream.rs

1use std::collections::HashMap;
2use std::fmt::Debug;
3use std::sync::atomic::AtomicBool;
4use std::sync::atomic::AtomicU64;
5use std::sync::atomic::AtomicUsize;
6use std::sync::atomic::Ordering;
7use std::sync::Arc;
8use std::thread;
9use std::time::Duration;
10use std::time::Instant;
11use std::time::UNIX_EPOCH;
12
13use conjure_object::BearerToken;
14use nominal_api::tonic::io::nominal::scout::api::proto::points::PointsType;
15use nominal_api::tonic::io::nominal::scout::api::proto::Channel;
16use nominal_api::tonic::io::nominal::scout::api::proto::DoublePoint;
17use nominal_api::tonic::io::nominal::scout::api::proto::DoublePoints;
18use nominal_api::tonic::io::nominal::scout::api::proto::IntegerPoint;
19use nominal_api::tonic::io::nominal::scout::api::proto::IntegerPoints;
20use nominal_api::tonic::io::nominal::scout::api::proto::Points;
21use nominal_api::tonic::io::nominal::scout::api::proto::Series;
22use nominal_api::tonic::io::nominal::scout::api::proto::StringPoint;
23use nominal_api::tonic::io::nominal::scout::api::proto::StringPoints;
24use nominal_api::tonic::io::nominal::scout::api::proto::WriteRequestNominal;
25use parking_lot::Condvar;
26use parking_lot::Mutex;
27use parking_lot::MutexGuard;
28use tracing::debug;
29use tracing::error;
30use tracing::info;
31use tracing::warn;
32
33use crate::consumer::WriteRequestConsumer;
34
/// Newtype wrapper around the name of a channel (time series) in a datasource.
#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd)]
pub struct ChannelName(String);
37
/// Map key identifying a unique series: channel name plus its tag set.
/// Tags are stored as a sorted `Vec` of pairs so the key hashes identically
/// regardless of tag insertion order (`inner_add_points` sorts before keying).
#[derive(Clone, Debug, Eq, Hash, PartialEq, Ord, PartialOrd)]
struct ChannelKey(ChannelName, Vec<(String, String)>);
40
41impl From<&str> for ChannelName {
42    fn from(name: &str) -> Self {
43        Self(name.to_string())
44    }
45}
46
/// Supplies bearer tokens for authenticating outgoing requests.
///
/// `token` returns `None` when no credential is currently available.
/// NOTE(review): how a `None` token is handled (skip auth vs. fail the request)
/// is decided by the consumer implementation, which is not visible here.
pub trait AuthProvider: Clone + Send + Sync {
    fn token(&self) -> Option<BearerToken>;
}
50
/// Conversion into the protobuf `PointsType` payload accepted by
/// `NominalDatasourceStream::enqueue`.
pub trait IntoPoints {
    fn into_points(self) -> PointsType;
}
54
// Identity conversion: an already-built `PointsType` passes through unchanged.
impl IntoPoints for PointsType {
    fn into_points(self) -> PointsType {
        self
    }
}
60
// Wraps a raw vector of double points into the protobuf `DoublePoints` variant.
impl IntoPoints for Vec<DoublePoint> {
    fn into_points(self) -> PointsType {
        PointsType::DoublePoints(DoublePoints { points: self })
    }
}
66
// Wraps a raw vector of string points into the protobuf `StringPoints` variant.
impl IntoPoints for Vec<StringPoint> {
    fn into_points(self) -> PointsType {
        PointsType::StringPoints(StringPoints { points: self })
    }
}
72
// Wraps a raw vector of integer points into the protobuf `IntegerPoints` variant.
impl IntoPoints for Vec<IntegerPoint> {
    fn into_points(self) -> PointsType {
        PointsType::IntegerPoints(IntegerPoints { points: self })
    }
}
78
/// Tuning options for a `NominalDatasourceStream`.
#[derive(Debug, Clone)]
pub struct NominalStreamOpts {
    // Soft cap on points accumulated in one buffer before it stops accepting
    // new batches (a single oversized batch may still exceed it; see
    // `SeriesBuffer::has_capacity`).
    pub max_points_per_record: usize,
    // Maximum time a batch-processor thread stays parked between flush attempts.
    pub max_request_delay: Duration,
    // Capacity of the bounded channel between batch processors and dispatchers.
    pub max_buffered_requests: usize,
    // Number of dispatcher threads that deliver write requests to the consumer.
    pub request_dispatcher_tasks: usize,
}
86
impl Default for NominalStreamOpts {
    /// Defaults: 250k points per record, 100 ms flush delay, 4 buffered
    /// requests, 8 dispatcher threads.
    fn default() -> Self {
        Self {
            max_points_per_record: 250_000,
            max_request_delay: Duration::from_millis(100),
            max_buffered_requests: 4,
            request_dispatcher_tasks: 8,
        }
    }
}
97
/// Double-buffered, multi-threaded stream that batches points into write
/// requests and hands them to a `WriteRequestConsumer` on background threads.
pub struct NominalDatasourceStream {
    // Cleared in `Drop` to tell all worker threads to shut down.
    running: Arc<AtomicBool>,
    // Points enqueued but not yet delivered to the consumer; `Drop` polls
    // until this reaches zero.
    unflushed_points: Arc<AtomicUsize>,
    // Preferred buffer for incoming points.
    primary_buffer: Arc<SeriesBuffer>,
    // Overflow buffer used when the primary is at capacity.
    secondary_buffer: Arc<SeriesBuffer>,
    // Handles to the two batch-processor threads; kept so `enqueue` can
    // unpark them to force an early flush.
    primary_handle: thread::JoinHandle<()>,
    secondary_handle: thread::JoinHandle<()>,
}
106
impl NominalDatasourceStream {
    /// Builds a stream that batches enqueued points and delivers them to `consumer`.
    ///
    /// Spawns:
    /// - two batch-processor threads (one per buffer) that drain their buffer
    ///   into write requests pushed onto a bounded channel, and
    /// - `opts.request_dispatcher_tasks` dispatcher threads that pull requests
    ///   off that channel and call the consumer.
    ///
    /// Dispatcher `JoinHandle`s are discarded (threads are detached); shutdown
    /// is coordinated via the `running` flag and the `unflushed_points`
    /// counter (see the `Drop` impl).
    pub fn new_with_consumer<C: WriteRequestConsumer + 'static>(
        consumer: C,
        opts: NominalStreamOpts,
    ) -> Self {
        let primary_buffer = Arc::new(SeriesBuffer::new(opts.max_points_per_record));
        let secondary_buffer = Arc::new(SeriesBuffer::new(opts.max_points_per_record));

        // Bounded so batch processors apply backpressure when dispatchers lag.
        let (request_tx, request_rx) =
            crossbeam_channel::bounded::<(WriteRequestNominal, usize)>(opts.max_buffered_requests);

        let running = Arc::new(AtomicBool::new(true));
        let unflushed_points = Arc::new(AtomicUsize::new(0));

        let primary_handle = thread::Builder::new()
            .name("nmstream_primary".to_string())
            .spawn({
                let points_buffer = Arc::clone(&primary_buffer);
                let running = running.clone();
                let tx = request_tx.clone();
                move || {
                    batch_processor(running, points_buffer, tx, opts.max_request_delay);
                }
            })
            .unwrap();

        let secondary_handle = thread::Builder::new()
            .name("nmstream_secondary".to_string())
            .spawn({
                let secondary_buffer = Arc::clone(&secondary_buffer);
                let running = running.clone();
                // Note: moves the original `request_tx` (not a clone), so the
                // channel disconnects once both processor threads exit.
                move || {
                    batch_processor(
                        running,
                        secondary_buffer,
                        request_tx,
                        opts.max_request_delay,
                    );
                }
            })
            .unwrap();

        let consumer = Arc::new(consumer);

        for i in 0..opts.request_dispatcher_tasks {
            thread::Builder::new()
                .name(format!("nmstream_dispatch_{i}"))
                .spawn({
                    let running = Arc::clone(&running);
                    let unflushed_points = Arc::clone(&unflushed_points);
                    let rx = request_rx.clone();
                    let consumer = consumer.clone();
                    move || {
                        debug!("starting request dispatcher");
                        request_dispatcher(running, unflushed_points, rx, consumer);
                    }
                })
                .unwrap();
        }

        NominalDatasourceStream {
            running,
            unflushed_points,
            primary_buffer,
            secondary_buffer,
            primary_handle,
            secondary_handle,
        }
    }

    /// Queues a batch of points for `channel` with the given `tags`.
    ///
    /// Prefers the primary buffer, falls back to the secondary, and when both
    /// are full blocks on the least-recently-flushed buffer until a processor
    /// drains it.
    pub fn enqueue(
        &self,
        channel: ChannelName,
        tags: HashMap<String, String>,
        new_points: impl IntoPoints,
    ) {
        let new_points = new_points.into_points();
        let new_count = points_len(&new_points);

        // Counted as unflushed up front; dispatchers decrement after delivery.
        self.unflushed_points
            .fetch_add(new_count, Ordering::Release);

        if self.primary_buffer.has_capacity(new_count) {
            debug!("adding {} points to primary buffer", new_count);
            self.primary_buffer.add_points(channel, tags, new_points);
        } else if self.secondary_buffer.has_capacity(new_count) {
            // primary buffer is definitely full
            self.primary_handle.thread().unpark();
            debug!("adding {} points to secondary buffer", new_count);
            self.secondary_buffer.add_points(channel, tags, new_points);
        } else {
            warn!("both buffers are full, picking least recently flushed buffer to add to");
            // both buffers are full - wait on the buffer that flushed least recently (i.e more
            // likely that it's nearly done)
            // `<` compares the buffers' flush_time (see PartialOrd for SeriesBuffer).
            let buf = if self.primary_buffer < self.secondary_buffer {
                debug!("waiting for primary buffer to flush...");
                self.primary_handle.thread().unpark();
                &self.primary_buffer
            } else {
                debug!("waiting for secondary buffer to flush...");
                self.secondary_handle.thread().unpark();
                &self.secondary_buffer
            };
            buf.add_on_notify(channel, tags, new_points);
            debug!("added points after wait to chosen buffer")
        }
    }
}
215
/// Accumulates points per (channel, tags) series until a batch processor
/// drains it via `take`.
struct SeriesBuffer {
    // Buffered points keyed by channel + sorted tag set.
    points: Mutex<HashMap<ChannelKey, PointsType>>,
    // Total buffered point count, tracked atomically so capacity checks
    // avoid taking the map lock.
    count: AtomicUsize,
    // Time of the last `take`, as nanoseconds since the Unix epoch; used to
    // pick the least-recently-flushed buffer when both are full.
    flush_time: AtomicU64,
    // Signalled after a flush so writers blocked in `add_on_notify` can proceed.
    condvar: Condvar,
    // Soft capacity limit (see `has_capacity`).
    max_capacity: usize,
}
223
// Equality (and, below, ordering) for `SeriesBuffer` compares only the last
// flush timestamp — NOT buffer contents. This exists so `enqueue` can pick
// the least-recently-flushed buffer with `<`.
impl PartialEq for SeriesBuffer {
    fn eq(&self, other: &Self) -> bool {
        self.flush_time.load(Ordering::Acquire) == other.flush_time.load(Ordering::Acquire)
    }
}
229
230impl PartialOrd for SeriesBuffer {
231    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
232        let flush_time = self.flush_time.load(Ordering::Acquire);
233        let other_flush_time = other.flush_time.load(Ordering::Acquire);
234        flush_time.partial_cmp(&other_flush_time)
235    }
236}
237
238impl SeriesBuffer {
239    fn new(capacity: usize) -> Self {
240        Self {
241            points: Mutex::new(HashMap::new()),
242            count: AtomicUsize::new(0),
243            flush_time: AtomicU64::new(0),
244            condvar: Condvar::new(),
245            max_capacity: capacity,
246        }
247    }
248
249    /// Checks if the buffer has enough capacity to add new points.
250    /// Note that the buffer can be larger than MAX_POINTS_PER_RECORD if a single batch of points
251    /// larger than MAX_POINTS_PER_RECORD is inserted while the buffer is empty. This avoids needing
252    /// to handle splitting batches of points across multiple requests.
253    fn has_capacity(&self, new_points_count: usize) -> bool {
254        let count = self.count.load(Ordering::Acquire);
255        count == 0 || count + new_points_count <= self.max_capacity
256    }
257
258    fn add_points(
259        &self,
260        channel: ChannelName,
261        tags: HashMap<String, String>,
262        new_points: PointsType,
263    ) {
264        self.inner_add_points(channel, tags, new_points, self.points.lock());
265    }
266
267    fn inner_add_points(
268        &self,
269        channel: ChannelName,
270        tags: HashMap<String, String>,
271        new_points: PointsType,
272        mut points_guard: MutexGuard<HashMap<ChannelKey, PointsType>>,
273    ) {
274        let mut tags = tags.into_iter().collect::<Vec<_>>();
275        tags.sort();
276        self.count
277            .fetch_add(points_len(&new_points), Ordering::Release);
278        let key = ChannelKey(channel, tags);
279        match (points_guard.get_mut(&key), new_points) {
280            (None, new_points) => {
281                points_guard.insert(key, new_points);
282            }
283            (Some(PointsType::DoublePoints(points)), PointsType::DoublePoints(new_points)) => {
284                points
285                    .points
286                    .extend_from_slice(new_points.points.as_slice());
287            }
288            (Some(PointsType::StringPoints(points)), PointsType::StringPoints(new_points)) => {
289                points
290                    .points
291                    .extend_from_slice(new_points.points.as_slice());
292            }
293            (Some(PointsType::IntegerPoints(points)), PointsType::IntegerPoints(new_points)) => {
294                points
295                    .points
296                    .extend_from_slice(new_points.points.as_slice());
297            }
298            (Some(PointsType::DoublePoints(_)), PointsType::StringPoints(_)) => {
299                // todo: return an error instead of panicking
300                panic!(
301                    "attempting to add points of the wrong type to an existing channel. expected: double. provided: string"
302                )
303            }
304            (Some(PointsType::DoublePoints(_)), PointsType::IntegerPoints(_)) => {
305                // todo: return an error instead of panicking
306                panic!(
307                    "attempting to add points of the wrong type to an existing channel. expected: double. provided: string"
308                )
309            }
310            (Some(PointsType::StringPoints(_)), PointsType::DoublePoints(_)) => {
311                // todo: return an error instead of panicking
312                panic!(
313                    "attempting to add points of the wrong type to an existing channel. expected: string. provided: double"
314                )
315            }
316            (Some(PointsType::StringPoints(_)), PointsType::IntegerPoints(_)) => {
317                // todo: return an error instead of panicking
318                panic!(
319                    "attempting to add points of the wrong type to an existing channel. expected: string. provided: double"
320                )
321            }
322            (Some(PointsType::IntegerPoints(_)), PointsType::DoublePoints(_)) => {
323                // todo: return an error instead of panicking
324                panic!(
325                    "attempting to add points of the wrong type to an existing channel. expected: string. provided: double"
326                )
327            }
328            (Some(PointsType::IntegerPoints(_)), PointsType::StringPoints(_)) => {
329                // todo: return an error instead of panicking
330                panic!(
331                    "attempting to add points of the wrong type to an existing channel. expected: string. provided: double"
332                )
333            }
334        }
335    }
336
337    fn take(&self) -> (usize, Vec<Series>) {
338        let mut points = self.points.lock();
339        self.flush_time.store(
340            UNIX_EPOCH.elapsed().unwrap().as_nanos() as u64,
341            Ordering::Release,
342        );
343        let result = points
344            .drain()
345            .map(|(ChannelKey(channel_name, tags), points)| {
346                let channel = Channel {
347                    name: channel_name.0,
348                };
349                let points_obj = Points {
350                    points_type: Some(points),
351                };
352                Series {
353                    channel: Some(channel),
354                    tags: tags.into_iter().collect(),
355                    points: Some(points_obj),
356                }
357            })
358            .collect();
359        let result_count = self
360            .count
361            .fetch_update(Ordering::Release, Ordering::Acquire, |_| Some(0))
362            .unwrap();
363        (result_count, result)
364    }
365
366    fn is_empty(&self) -> bool {
367        self.count() == 0
368    }
369
370    fn count(&self) -> usize {
371        self.count.load(Ordering::Acquire)
372    }
373
374    fn add_on_notify(
375        &self,
376        channel: ChannelName,
377        tags: HashMap<String, String>,
378        new_points: PointsType,
379    ) {
380        let mut points_lock = self.points.lock();
381        // concurrency bug without this - the buffer could have been emptied since we
382        // checked the count, so this will wait forever & block any new points from entering
383        if !points_lock.is_empty() {
384            self.condvar.wait(&mut points_lock);
385        } else {
386            debug!("buffer emptied since last check, skipping condvar wait");
387        }
388        self.inner_add_points(channel, tags, new_points, points_lock);
389    }
390
391    fn notify(&self) -> bool {
392        self.condvar.notify_one()
393    }
394}
395
396fn batch_processor(
397    running: Arc<AtomicBool>,
398    points_buffer: Arc<SeriesBuffer>,
399    request_chan: crossbeam_channel::Sender<(WriteRequestNominal, usize)>,
400    max_request_delay: Duration,
401) {
402    loop {
403        debug!("starting processor loop");
404        if points_buffer.is_empty() {
405            if !running.load(Ordering::Acquire) {
406                debug!("batch processor thread exiting due to running flag");
407                drop(request_chan);
408                break;
409            } else {
410                debug!("empty points buffer, waiting");
411                thread::park_timeout(max_request_delay);
412            }
413            continue;
414        }
415        let (point_count, series) = points_buffer.take();
416
417        if points_buffer.notify() {
418            debug!("notified one waiting thread after clearing points buffer");
419        }
420
421        let write_request = WriteRequestNominal { series };
422
423        if request_chan.is_full() {
424            warn!("request channel is full");
425        }
426        let rep = request_chan.send((write_request, point_count));
427        debug!("queued request for processing");
428        if rep.is_err() {
429            error!("failed to send request to dispatcher");
430        } else {
431            debug!("finished submitting request");
432        }
433
434        thread::park_timeout(max_request_delay);
435    }
436    debug!("batch processor thread exiting");
437}
438
439impl Drop for NominalDatasourceStream {
440    fn drop(&mut self) {
441        debug!("starting drop for NominalDatasourceStream");
442        self.running.store(false, Ordering::Release);
443        while self.unflushed_points.load(Ordering::Acquire) > 0 {
444            debug!(
445                "waiting for all points to be flushed before dropping stream, {} points remaining",
446                self.unflushed_points.load(Ordering::Acquire)
447            );
448            // todo: reduce this + give up after some maximum timeout is reached
449            thread::sleep(Duration::from_millis(50));
450        }
451    }
452}
453
454fn request_dispatcher<C: WriteRequestConsumer + 'static>(
455    running: Arc<AtomicBool>,
456    unflushed_points: Arc<AtomicUsize>,
457    request_rx: crossbeam_channel::Receiver<(WriteRequestNominal, usize)>,
458    consumer: Arc<C>,
459) {
460    let mut total_request_time = 0;
461    loop {
462        match request_rx.recv() {
463            Ok((request, point_count)) => {
464                debug!("received writerequest from channel");
465                let req_start = Instant::now();
466                match consumer.consume(&request) {
467                    Ok(_) => {
468                        let time = req_start.elapsed().as_millis();
469                        debug!("request of {} points sent in {} ms", point_count, time);
470                        total_request_time += time as u64;
471                    }
472                    Err(e) => {
473                        error!("Failed to send request: {e:?}");
474                    }
475                }
476                unflushed_points.fetch_sub(point_count, Ordering::Release);
477
478                if unflushed_points.load(Ordering::Acquire) == 0 && !running.load(Ordering::Acquire)
479                {
480                    info!("all points flushed, closing dispatcher thread");
481                    // notify the processor thread that all points have been flushed
482                    drop(request_rx);
483                    break;
484                }
485            }
486            Err(e) => {
487                debug!("request channel closed, exiting dispatcher thread. info: '{e}'");
488                break;
489            }
490        }
491    }
492    debug!(
493        "request dispatcher thread exiting. total request time: {}",
494        total_request_time
495    );
496}
497
498fn points_len(points_type: &PointsType) -> usize {
499    match points_type {
500        PointsType::DoublePoints(points) => points.points.len(),
501        PointsType::StringPoints(points) => points.points.len(),
502        PointsType::IntegerPoints(points) => points.points.len(),
503    }
504}