braid_triggerbox/
lib.rs

1#[macro_use]
2extern crate log;
3
4mod datetime_conversion;
5
6mod arduino_udev;
7use crate::arduino_udev::serial_handshake;
8
9use anyhow::{Context, Result};
10use chrono::Duration;
11
12use nalgebra as na;
13
14use std::collections::BTreeMap;
15use tokio::{
16    io::{AsyncReadExt, AsyncWriteExt},
17    sync::mpsc::{Receiver, Sender},
18};
19
20use braid_triggerbox_comms::{Prescaler, TopAndPrescaler, DEVICE_FIRMWARE_VERSION};
21
22// ----- name type handling
/// Size, in bytes, of the fixed-width name stored on a triggerbox.
pub const DEVICE_NAME_LEN: usize = 8;

/// Raw name bytes. Names shorter than [`DEVICE_NAME_LEN`] are padded with
/// zero bytes by [`to_name_type`].
pub type InnerNameType = [u8; DEVICE_NAME_LEN];

/// A device name that may be absent (e.g. an unnamed device).
pub type NameType = Option<InnerNameType>;
27
28pub type ClockModelCallback = Box<dyn FnMut(Option<ClockModel>) + Send>;
29
30pub fn to_name_type(x: &str) -> anyhow::Result<InnerNameType> {
31    let mut name = [0; DEVICE_NAME_LEN];
32    let bytes = x.as_bytes();
33    if bytes.len() > DEVICE_NAME_LEN {
34        anyhow::bail!("Maximum name length ({} chars) exceeded.", DEVICE_NAME_LEN);
35    }
36    name[..bytes.len()].copy_from_slice(bytes);
37    Ok(name)
38}
39
40pub fn name_display(name: &NameType) -> String {
41    if let Some(name) = name {
42        format!("\"{}\"", String::from_utf8_lossy(name))
43    } else {
44        "none".into()
45    }
46}
47
48// ------ clock model types
49
/// A linear model relating triggerbox time to host time.
///
/// Produced from the output of `fit_time_model`, which fits
/// `host_time ≈ gain * triggerbox_stamp + offset` over recent samples.
#[derive(Clone, Debug, PartialEq)]
pub struct ClockModel {
    /// Slope of the fit (host-time units per triggerbox pulse).
    pub gain: f64,
    /// Intercept of the fit.
    pub offset: f64,
    /// Residuals value reported by the least-squares solver for this fit.
    pub residuals: f64,
    /// How many samples the fit was computed from.
    pub n_measurements: u64,
}
57
58#[derive(Debug)]
59pub struct TriggerClockInfoRow {
60    // changes to this should update BraidMetadataSchemaTag
61    pub start_timestamp: chrono::DateTime<chrono::Utc>,
62    pub framecount: i64,
63    pub tcnt: u8,
64    pub stop_timestamp: chrono::DateTime<chrono::Utc>,
65}
66
67/// A Braid Triggerbox device.
68pub struct TriggerboxDevice {
69    icr1_and_prescaler: Option<TopAndPrescaler>,
70    version_check_done: bool,
71    qi: u8,
72    queries: BTreeMap<u8, chrono::DateTime<chrono::Utc>>,
73    ser: tokio_serial::SerialStream,
74    outq: Receiver<Cmd>,
75    vquery_time: chrono::DateTime<chrono::Utc>,
76    last_time: chrono::DateTime<chrono::Utc>,
77    past_data: Vec<(f64, f64)>,
78    allow_requesting_clock_sync: bool,
79    on_new_model_cb: ClockModelCallback,
80    triggerbox_data_tx: Option<Sender<TriggerClockInfoRow>>,
81    max_acceptable_measurement_error: Duration,
82}
83
84#[derive(Debug, Clone)]
85pub enum Cmd {
86    TopAndPrescaler(TopAndPrescaler),
87    StopPulsesAndReset,
88    StartPulses,
89    SetDeviceName(InnerNameType),
90    SetAOut((f64, f64)),
91}
92
93impl TriggerboxDevice {
94    pub async fn new(
95        on_new_model_cb: ClockModelCallback,
96        device_path: String,
97        outq: Receiver<Cmd>,
98        triggerbox_data_tx: Option<Sender<TriggerClockInfoRow>>,
99        assert_device_name: NameType,
100        max_acceptable_measurement_error: std::time::Duration,
101        sleep_dur: std::time::Duration,
102    ) -> Result<Self> {
103        let baud_rate = 115_200;
104        let max_acceptable_measurement_error =
105            Duration::from_std(max_acceptable_measurement_error).unwrap();
106        let now = chrono::Utc::now();
107
108        // wait 1 second before first version query
109        let vquery_time = now + Duration::seconds(1);
110
111        debug!("Opening device at path {}", device_path);
112
113        let (ser, name) = match tokio::time::timeout(
114            std::time::Duration::from_millis(15_000),
115            serial_handshake(&device_path, baud_rate, sleep_dur),
116        )
117        .await
118        {
119            Ok(r) => r,
120            Err(elapsed) => Err(elapsed).map_err(anyhow::Error::from),
121        }
122        .with_context(|| format!("opening device {device_path}"))?;
123
124        if let Some(name) = &name {
125            let name_str = String::from_utf8_lossy(name);
126            debug!("Connected to device named \"{}\".", name_str);
127        } else {
128            debug!("Connected to unnamed device.");
129        }
130
131        if assert_device_name.is_some() && name != assert_device_name {
132            anyhow::bail!(
133                "Found name {}, but expected {}. ({:?} vs {:?}.)",
134                name_display(&name),
135                name_display(&assert_device_name),
136                name,
137                assert_device_name,
138            );
139        }
140
141        Ok(Self {
142            icr1_and_prescaler: None,
143            version_check_done: false,
144            qi: 0,
145            queries: BTreeMap::new(),
146            ser,
147            outq,
148            vquery_time,
149            last_time: vquery_time + Duration::seconds(1),
150            past_data: Vec::new(),
151            allow_requesting_clock_sync: false,
152            on_new_model_cb,
153            triggerbox_data_tx,
154            max_acceptable_measurement_error,
155        })
156    }
157
158    async fn write(&mut self, buf: &[u8]) -> tokio::io::Result<()> {
159        trace!("sending: \"{}\"", String::from_utf8_lossy(buf));
160        for byte in buf.iter() {
161            trace!("sending byte: {}", byte);
162        }
163        AsyncWriteExt::write_all(&mut self.ser, buf).await?;
164        Ok(())
165    }
166
167    async fn handle_host_command(&mut self, cmd: Cmd) -> Result<()> {
168        debug!("got command {:?}", cmd);
169        match cmd {
170            Cmd::TopAndPrescaler(new_value) => {
171                self._set_top_and_prescaler(new_value).await?;
172            }
173            Cmd::StopPulsesAndReset => {
174                debug!("will reset counters. dropping outstanding info requests.");
175                self.allow_requesting_clock_sync = false;
176                self.queries.clear();
177                self.past_data.clear();
178                (self.on_new_model_cb)(None);
179                self.write(b"S0").await?;
180            }
181            Cmd::StartPulses => {
182                self.allow_requesting_clock_sync = true;
183                self.write(b"S1").await?;
184            }
185            Cmd::SetDeviceName(name) => {
186                let computed_crc = format!("{:X}", arduino_udev::CRC_MAXIM.checksum(&name));
187                trace!("computed CRC: {:?}", computed_crc);
188
189                self.write(b"N=").await?;
190                self.write(&name).await?;
191                self.write(computed_crc.as_bytes()).await?;
192            }
193            Cmd::SetAOut((volts1, volts2)) => {
194                fn volts_to_dac(volts: f64) -> u16 {
195                    // Convert voltage to fraction and clamp.
196                    let frac = (volts / 4.096).clamp(0.0, 1.0);
197                    // Compute integer DAC value.
198                    let val: u16 = (frac * 4095.0).round() as u16;
199                    val
200                }
201                let val1 = volts_to_dac(volts1);
202                let val2 = volts_to_dac(volts2);
203
204                self.write(b"O=").await?;
205                self.write(&val1.to_le_bytes()).await?;
206                self.write(&val2.to_le_bytes()).await?;
207                self.write(b"x").await?;
208
209                // Now wait for return value.
210                tokio::time::sleep(std::time::Duration::from_millis(50)).await;
211
212                let mut buf = vec![0; 100];
213                let len = self.ser.read(&mut buf).await?;
214                let buf = &buf[..len];
215                debug!("AOUT ignoring values: {:?}", buf);
216            }
217        }
218        Ok(())
219    }
220
221    /// Run forever, handling interaction with the triggerbox hardware device.
222    ///
223    /// Drop all instances of the `Sender<Cmd>` which could send messages to the
224    /// `Receiver<Cmd>` passed to [Self::new] to exit.
225    pub async fn run_forever(
226        mut self: TriggerboxDevice,
227        query_dt: std::time::Duration,
228    ) -> Result<()> {
229        let query_dt = Duration::from_std(query_dt)?;
230
231        let mut now = chrono::Utc::now();
232
233        let connect_time = now;
234
235        let mut buf: Vec<u8> = Vec::new();
236        let mut read_buf: Vec<u8> = vec![0; 100];
237        let mut version_check_started = false;
238        let mut new_data = false;
239        let mut interval = tokio::time::interval(std::time::Duration::from_millis(100));
240
241        fn update_read_buffer(n_bytes_read: usize, read_buf: &[u8], buf: &mut Vec<u8>) {
242            for i in 0..n_bytes_read {
243                let byte = read_buf[i];
244                trace!(
245                    "read byte {} (char {})",
246                    byte,
247                    String::from_utf8_lossy(&read_buf[i..i + 1])
248                );
249                buf.push(byte);
250            }
251        }
252
253        loop {
254            if self.version_check_done {
255                tokio::select! {
256                    // Handle command queue iff version check done.
257                    opt_cmd_tup = self.outq.recv() => {
258                        match opt_cmd_tup {
259                            Some(cmd) => {
260                                self.handle_host_command(cmd).await?;
261                            }
262                            None => {
263                                // no more commands, sender hung up
264                                info!("exiting run loop");
265                                return Ok(());
266                            }
267                        }
268                    },
269                    res_r = self.ser.read(&mut read_buf) => {
270                        let n_bytes_read = res_r?;
271                        if n_bytes_read > 0 {
272                            update_read_buffer(n_bytes_read,&read_buf,&mut buf);
273                            new_data = true;
274                        }
275                    },
276                    _ = interval.tick() => {}
277                }
278            } else {
279                // Same as above except `self.outq` is not checked. This is done
280                // at startup before the version number is confirmed.
281                tokio::select! {
282                    res_r = self.ser.read(&mut read_buf) => {
283                        let n_bytes_read = res_r?;
284                        if n_bytes_read > 0 {
285                            update_read_buffer(n_bytes_read,&read_buf,&mut buf);
286                            new_data = true;
287                        }
288                    },
289                    _ = interval.tick() => {}
290                }
291            }
292
293            // handle pending data, if any
294            if new_data {
295                buf = self.handle_data_from_device(buf).await?;
296                new_data = false;
297            }
298
299            now = chrono::Utc::now();
300
301            if self.version_check_done {
302                if self.allow_requesting_clock_sync
303                    & (now.signed_duration_since(self.last_time) > query_dt)
304                {
305                    // request sample
306                    debug!("making clock sample request. qi: {}, now: {}", self.qi, now);
307                    self.queries.insert(self.qi, now);
308                    let send_buf = [b'P', self.qi];
309                    self.write(&send_buf).await?;
310                    self.qi = self.qi.wrapping_add(1);
311                    self.last_time = now;
312                }
313            } else {
314                // request firmware version
315                if !version_check_started && now >= self.vquery_time {
316                    info!("checking firmware version");
317                    self.write(b"V?").await?;
318                    version_check_started = true;
319                    self.vquery_time = now;
320                }
321
322                // retry every second
323                if now.signed_duration_since(self.vquery_time) > Duration::seconds(1) {
324                    version_check_started = false;
325                }
326                // give up after 20 seconds
327                if now.signed_duration_since(connect_time) > Duration::seconds(20) {
328                    return Err(anyhow::anyhow!("no version response"));
329                }
330            }
331        }
332    }
333
334    async fn _set_top_and_prescaler(&mut self, new_value: TopAndPrescaler) -> Result<()> {
335        use byteorder::{ByteOrder, LittleEndian};
336
337        let mut buf = [0, 0, 0];
338        LittleEndian::write_u16(&mut buf[0..2], new_value.avr_icr1());
339        buf[2] = new_value.prescaler_key();
340
341        self.icr1_and_prescaler = Some(new_value);
342
343        self.write(b"T=").await?;
344        self.write(&buf).await?;
345        Ok(())
346    }
347
348    async fn _handle_returned_timestamp(
349        &mut self,
350        qi: u8,
351        pulsenumber: u32,
352        count: u16,
353    ) -> Result<()> {
354        debug!(
355            "got returned timestamp with qi: {}, pulsenumber: {}, count: {}",
356            qi, pulsenumber, count
357        );
358        let now = chrono::Utc::now();
359        while self.queries.len() > 50 {
360            self.queries.clear();
361            error!("too many outstanding queries");
362        }
363
364        let send_timestamp = match self.queries.remove(&qi) {
365            Some(send_timestamp) => send_timestamp,
366            None => {
367                warn!("could not find original data for query {:?}", qi);
368                return Ok(());
369            }
370        };
371        trace!("this query has send_timestamp: {}", send_timestamp);
372
373        let max_error = now.signed_duration_since(send_timestamp);
374        if max_error > self.max_acceptable_measurement_error {
375            debug!("clock sample took {:?}. Ignoring value.", max_error);
376            return Ok(());
377        }
378
379        trace!("max_error: {:?}", max_error);
380
381        let ino_time_estimate = send_timestamp + (max_error / 2);
382
383        match &self.icr1_and_prescaler {
384            Some(s) => {
385                let frac = count as f64 / s.avr_icr1() as f64;
386                debug_assert!(0.0 <= frac);
387                debug_assert!(frac <= 1.0);
388                let ino_stamp = na::convert(pulsenumber as f64 + frac);
389
390                if let Some(ref tbox_tx) = self.triggerbox_data_tx {
391                    // send our newly acquired data to be saved to disk
392                    let to_save = TriggerClockInfoRow {
393                        start_timestamp: send_timestamp,
394                        framecount: pulsenumber as i64,
395                        tcnt: (frac * 255.0) as u8,
396                        stop_timestamp: now,
397                    };
398                    match tbox_tx.send(to_save).await {
399                        Ok(()) => {}
400                        Err(e) => {
401                            warn!("ignoring {}", e);
402                        }
403                    }
404                }
405
406                // delete old data
407                while self.past_data.len() >= 100 {
408                    self.past_data.remove(0);
409                }
410
411                self.past_data.push((
412                    ino_stamp,
413                    datetime_conversion::datetime_to_f64(&ino_time_estimate),
414                ));
415
416                if self.past_data.len() >= 5 {
417                    let (gain, offset, residuals) = fit_time_model(&self.past_data)
418                        .map_err(|e| anyhow::anyhow!("lstsq err: {}", e))?;
419
420                    let n_measurements = self.past_data.len() as u64;
421                    let per_point_residual = residuals / n_measurements as f64;
422                    // TODO only accept this if residuals less than some amount?
423                    debug!(
424                        "new: ClockModel{{gain: {}, offset: {}}}, per_point_residual: {}",
425                        gain, offset, per_point_residual
426                    );
427                    (self.on_new_model_cb)(Some(ClockModel {
428                        gain,
429                        offset,
430                        residuals,
431                        n_measurements,
432                    }));
433                }
434            }
435            None => {
436                warn!("No clock measurements until framerate set.");
437            }
438        }
439        Ok(())
440    }
441
442    fn _handle_version(&mut self, value: u8, _pulsenumber: u32, _count: u16) -> Result<()> {
443        trace!("got returned version with value: {}", value);
444        assert_eq!(value, DEVICE_FIRMWARE_VERSION);
445        self.vquery_time = chrono::Utc::now();
446        self.version_check_done = true;
447        info!("connected to triggerbox firmware version {}", value);
448        Ok(())
449    }
450
451    async fn handle_data_from_device(&mut self, buf: Vec<u8>) -> Result<Vec<u8>> {
452        if buf.len() >= 3 {
453            // header, length, checksum is minimum
454            let mut valid_n_chars = None;
455
456            let packet_type = buf[0] as char;
457            let payload_len = buf[1];
458
459            let min_valid_packet_size = 3 + payload_len as usize; // header (2) + payload + checksum (1)
460            if buf.len() >= min_valid_packet_size {
461                let expected_chksum = buf[2 + payload_len as usize];
462
463                let check_buf = &buf[2..buf.len() - 1];
464                let bytes = check_buf;
465                let actual_chksum = bytes.iter().fold(0, |acc: u8, x| acc.wrapping_add(*x));
466
467                if actual_chksum == expected_chksum {
468                    trace!("checksum OK");
469                    valid_n_chars = Some(bytes.len() + 3)
470                } else {
471                    return Err(anyhow::anyhow!("checksum mismatch"));
472                }
473
474                if (packet_type == 'P') | (packet_type == 'V') {
475                    assert!(payload_len == 7);
476                    let value = bytes[0];
477
478                    use byteorder::{ByteOrder, LittleEndian};
479                    let pulsenumber = LittleEndian::read_u32(&bytes[1..5]);
480                    let count = LittleEndian::read_u16(&bytes[5..7]);
481
482                    match packet_type {
483                        'P' => {
484                            self._handle_returned_timestamp(value, pulsenumber, count)
485                                .await?
486                        }
487                        'V' => self._handle_version(value, pulsenumber, count)?,
488                        _ => unreachable!(),
489                    };
490                }
491            }
492
493            if let Some(n_used_chars) = valid_n_chars {
494                return Ok(buf[n_used_chars..].to_vec());
495            }
496        }
497        Ok(buf)
498    }
499}
500
501fn fit_time_model(past_data: &[(f64, f64)]) -> Result<(f64, f64, f64), &'static str> {
502    use na::{OMatrix, OVector, U2};
503
504    let mut a: Vec<f64> = Vec::with_capacity(past_data.len() * 2);
505    let mut b: Vec<f64> = Vec::with_capacity(past_data.len());
506
507    for row in past_data.iter() {
508        a.push(row.0);
509        a.push(1.0);
510        b.push(row.1);
511    }
512    let a = OMatrix::<f64, na::Dyn, U2>::from_row_slice(&a);
513    let b = OVector::<f64, na::Dyn>::from_row_slice(&b);
514
515    let epsilon = 1e-10;
516    let results = lstsq::lstsq(&a, &b, epsilon)?;
517
518    let gain = results.solution[0];
519    let offset = results.solution[1];
520    let residuals = results.residuals;
521
522    Ok((gain, offset, residuals))
523}
524
#[test]
fn test_fit_time_model() {
    const TOLERANCE: f64 = 1e-12;

    // (samples, expected gain, expected offset)
    let cases = vec![
        (vec![(0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (3.0, 3.0)], 1.0, 0.0),
        (vec![(0.0, 12.0), (1.0, 22.0), (2.0, 32.0), (3.0, 42.0)], 10.0, 12.0),
    ];

    for (samples, want_gain, want_offset) in &cases {
        let (gain, offset, _residuals) = fit_time_model(samples).unwrap();
        assert!((gain - want_gain).abs() < TOLERANCE);
        assert!((offset - want_offset).abs() < TOLERANCE);
    }
}
539
540#[derive(Clone, Debug)]
541pub struct TriggerboxOptions {
542    pub device_path: String,
543    pub query_dt: std::time::Duration,
544    pub assert_device_name: NameType,
545    pub max_acceptable_measurement_error: std::time::Duration,
546    pub sleep_dur: std::time::Duration,
547}
548
549pub async fn run_triggerbox(
550    on_new_model_cb: ClockModelCallback,
551    outq: Receiver<Cmd>,
552    triggerbox_data_tx: Option<Sender<TriggerClockInfoRow>>,
553    opts: TriggerboxOptions,
554) -> Result<()> {
555    let TriggerboxOptions {
556        device_path,
557        query_dt,
558        assert_device_name,
559        max_acceptable_measurement_error,
560        sleep_dur,
561    } = opts;
562
563    let triggerbox = TriggerboxDevice::new(
564        on_new_model_cb,
565        device_path,
566        outq,
567        triggerbox_data_tx,
568        assert_device_name,
569        max_acceptable_measurement_error,
570        sleep_dur,
571    )
572    .await?;
573    triggerbox.run_forever(query_dt).await
574}
575
576fn get_rate(rate_ideal: f64, prescaler: Prescaler) -> (u16, f64) {
577    let xtal = 16e6; // 16 MHz clock
578    let base_clock = xtal / prescaler.as_f64();
579    let new_top_ideal = base_clock / rate_ideal;
580    let new_icr1_f64 = new_top_ideal.round();
581    let new_icr1: u16 = if new_icr1_f64 > 0xFFFF as f64 {
582        0xFFFF
583    } else if new_icr1_f64 < 0.0 {
584        0
585    } else {
586        new_icr1_f64 as u16
587    };
588    let rate_actual = base_clock / new_icr1 as f64;
589    (new_icr1, rate_actual)
590}
591
592/// Given an ideal frame rate (in frames per second), compute the triggerbox
593/// command which best approximates this frame rate.
594///
595/// Returns the triggerbox command and the expected actual frame rate (in frames
596/// per second).
597pub fn make_trig_fps_cmd(rate_ideal: f64) -> (Cmd, f64) {
598    let (top_8, rate_actual_8) = get_rate(rate_ideal, Prescaler::Scale8);
599    let (top_64, rate_actual_64) = get_rate(rate_ideal, Prescaler::Scale64);
600
601    let error_8 = (rate_ideal - rate_actual_8).abs();
602    let error_64 = (rate_ideal - rate_actual_64).abs();
603
604    let (top, rate_actual, prescaler) = if error_8 < error_64 {
605        (top_8, rate_actual_8, Prescaler::Scale8)
606    } else {
607        (top_64, rate_actual_64, Prescaler::Scale64)
608    };
609
610    (
611        Cmd::TopAndPrescaler(TopAndPrescaler::new_avr(top, prescaler)),
612        rate_actual,
613    )
614}