1#[macro_use]
2extern crate log;
3
4mod datetime_conversion;
5
6mod arduino_udev;
7use crate::arduino_udev::serial_handshake;
8
9use anyhow::{Context, Result};
10use chrono::Duration;
11
12use nalgebra as na;
13
14use std::collections::BTreeMap;
15use tokio::{
16 io::{AsyncReadExt, AsyncWriteExt},
17 sync::mpsc::{Receiver, Sender},
18};
19
20use braid_triggerbox_comms::{Prescaler, TopAndPrescaler, DEVICE_FIRMWARE_VERSION};
21
/// Number of bytes reserved for a triggerbox device name.
pub const DEVICE_NAME_LEN: usize = 8;

/// Raw, fixed-size device name. Names built with `to_name_type` are
/// zero-padded when shorter than [`DEVICE_NAME_LEN`].
pub type InnerNameType = [u8; DEVICE_NAME_LEN];
/// A device name that may be absent (`None` for an unnamed device).
pub type NameType = Option<[u8; DEVICE_NAME_LEN]>;
27
28pub type ClockModelCallback = Box<dyn FnMut(Option<ClockModel>) + Send>;
29
30pub fn to_name_type(x: &str) -> anyhow::Result<InnerNameType> {
31 let mut name = [0; DEVICE_NAME_LEN];
32 let bytes = x.as_bytes();
33 if bytes.len() > DEVICE_NAME_LEN {
34 anyhow::bail!("Maximum name length ({} chars) exceeded.", DEVICE_NAME_LEN);
35 }
36 name[..bytes.len()].copy_from_slice(bytes);
37 Ok(name)
38}
39
40pub fn name_display(name: &NameType) -> String {
41 if let Some(name) = name {
42 format!("\"{}\"", String::from_utf8_lossy(name))
43 } else {
44 "none".into()
45 }
46}
47
/// Linear model relating the triggerbox clock to host time.
///
/// It is fit by least squares over recent samples as
/// `host_time ≈ gain * device_stamp + offset`. Here `device_stamp` is the
/// pulse number plus the elapsed fraction of the current pulse period.
/// `host_time` is the host timestamp converted with
/// `datetime_conversion::datetime_to_f64`.
#[derive(Debug, PartialEq, Clone)]
pub struct ClockModel {
    /// Slope of the fit: host time per device pulse.
    pub gain: f64,
    /// Intercept of the fit: host time at device stamp zero.
    pub offset: f64,
    /// Residual reported by `lstsq` for the whole fit. It is not normalized
    /// per point; the per-point value is `residuals / n_measurements`.
    pub residuals: f64,
    /// Number of samples that went into the fit.
    pub n_measurements: u64,
}
57
58#[derive(Debug)]
59pub struct TriggerClockInfoRow {
60 pub start_timestamp: chrono::DateTime<chrono::Utc>,
62 pub framecount: i64,
63 pub tcnt: u8,
64 pub stop_timestamp: chrono::DateTime<chrono::Utc>,
65}
66
67pub struct TriggerboxDevice {
69 icr1_and_prescaler: Option<TopAndPrescaler>,
70 version_check_done: bool,
71 qi: u8,
72 queries: BTreeMap<u8, chrono::DateTime<chrono::Utc>>,
73 ser: tokio_serial::SerialStream,
74 outq: Receiver<Cmd>,
75 vquery_time: chrono::DateTime<chrono::Utc>,
76 last_time: chrono::DateTime<chrono::Utc>,
77 past_data: Vec<(f64, f64)>,
78 allow_requesting_clock_sync: bool,
79 on_new_model_cb: ClockModelCallback,
80 triggerbox_data_tx: Option<Sender<TriggerClockInfoRow>>,
81 max_acceptable_measurement_error: Duration,
82}
83
84#[derive(Debug, Clone)]
85pub enum Cmd {
86 TopAndPrescaler(TopAndPrescaler),
87 StopPulsesAndReset,
88 StartPulses,
89 SetDeviceName(InnerNameType),
90 SetAOut((f64, f64)),
91}
92
93impl TriggerboxDevice {
94 pub async fn new(
95 on_new_model_cb: ClockModelCallback,
96 device_path: String,
97 outq: Receiver<Cmd>,
98 triggerbox_data_tx: Option<Sender<TriggerClockInfoRow>>,
99 assert_device_name: NameType,
100 max_acceptable_measurement_error: std::time::Duration,
101 sleep_dur: std::time::Duration,
102 ) -> Result<Self> {
103 let baud_rate = 115_200;
104 let max_acceptable_measurement_error =
105 Duration::from_std(max_acceptable_measurement_error).unwrap();
106 let now = chrono::Utc::now();
107
108 let vquery_time = now + Duration::seconds(1);
110
111 debug!("Opening device at path {}", device_path);
112
113 let (ser, name) = match tokio::time::timeout(
114 std::time::Duration::from_millis(15_000),
115 serial_handshake(&device_path, baud_rate, sleep_dur),
116 )
117 .await
118 {
119 Ok(r) => r,
120 Err(elapsed) => Err(elapsed).map_err(anyhow::Error::from),
121 }
122 .with_context(|| format!("opening device {device_path}"))?;
123
124 if let Some(name) = &name {
125 let name_str = String::from_utf8_lossy(name);
126 debug!("Connected to device named \"{}\".", name_str);
127 } else {
128 debug!("Connected to unnamed device.");
129 }
130
131 if assert_device_name.is_some() && name != assert_device_name {
132 anyhow::bail!(
133 "Found name {}, but expected {}. ({:?} vs {:?}.)",
134 name_display(&name),
135 name_display(&assert_device_name),
136 name,
137 assert_device_name,
138 );
139 }
140
141 Ok(Self {
142 icr1_and_prescaler: None,
143 version_check_done: false,
144 qi: 0,
145 queries: BTreeMap::new(),
146 ser,
147 outq,
148 vquery_time,
149 last_time: vquery_time + Duration::seconds(1),
150 past_data: Vec::new(),
151 allow_requesting_clock_sync: false,
152 on_new_model_cb,
153 triggerbox_data_tx,
154 max_acceptable_measurement_error,
155 })
156 }
157
158 async fn write(&mut self, buf: &[u8]) -> tokio::io::Result<()> {
159 trace!("sending: \"{}\"", String::from_utf8_lossy(buf));
160 for byte in buf.iter() {
161 trace!("sending byte: {}", byte);
162 }
163 AsyncWriteExt::write_all(&mut self.ser, buf).await?;
164 Ok(())
165 }
166
167 async fn handle_host_command(&mut self, cmd: Cmd) -> Result<()> {
168 debug!("got command {:?}", cmd);
169 match cmd {
170 Cmd::TopAndPrescaler(new_value) => {
171 self._set_top_and_prescaler(new_value).await?;
172 }
173 Cmd::StopPulsesAndReset => {
174 debug!("will reset counters. dropping outstanding info requests.");
175 self.allow_requesting_clock_sync = false;
176 self.queries.clear();
177 self.past_data.clear();
178 (self.on_new_model_cb)(None);
179 self.write(b"S0").await?;
180 }
181 Cmd::StartPulses => {
182 self.allow_requesting_clock_sync = true;
183 self.write(b"S1").await?;
184 }
185 Cmd::SetDeviceName(name) => {
186 let computed_crc = format!("{:X}", arduino_udev::CRC_MAXIM.checksum(&name));
187 trace!("computed CRC: {:?}", computed_crc);
188
189 self.write(b"N=").await?;
190 self.write(&name).await?;
191 self.write(computed_crc.as_bytes()).await?;
192 }
193 Cmd::SetAOut((volts1, volts2)) => {
194 fn volts_to_dac(volts: f64) -> u16 {
195 let frac = (volts / 4.096).clamp(0.0, 1.0);
197 let val: u16 = (frac * 4095.0).round() as u16;
199 val
200 }
201 let val1 = volts_to_dac(volts1);
202 let val2 = volts_to_dac(volts2);
203
204 self.write(b"O=").await?;
205 self.write(&val1.to_le_bytes()).await?;
206 self.write(&val2.to_le_bytes()).await?;
207 self.write(b"x").await?;
208
209 tokio::time::sleep(std::time::Duration::from_millis(50)).await;
211
212 let mut buf = vec![0; 100];
213 let len = self.ser.read(&mut buf).await?;
214 let buf = &buf[..len];
215 debug!("AOUT ignoring values: {:?}", buf);
216 }
217 }
218 Ok(())
219 }
220
221 pub async fn run_forever(
226 mut self: TriggerboxDevice,
227 query_dt: std::time::Duration,
228 ) -> Result<()> {
229 let query_dt = Duration::from_std(query_dt)?;
230
231 let mut now = chrono::Utc::now();
232
233 let connect_time = now;
234
235 let mut buf: Vec<u8> = Vec::new();
236 let mut read_buf: Vec<u8> = vec![0; 100];
237 let mut version_check_started = false;
238 let mut new_data = false;
239 let mut interval = tokio::time::interval(std::time::Duration::from_millis(100));
240
241 fn update_read_buffer(n_bytes_read: usize, read_buf: &[u8], buf: &mut Vec<u8>) {
242 for i in 0..n_bytes_read {
243 let byte = read_buf[i];
244 trace!(
245 "read byte {} (char {})",
246 byte,
247 String::from_utf8_lossy(&read_buf[i..i + 1])
248 );
249 buf.push(byte);
250 }
251 }
252
253 loop {
254 if self.version_check_done {
255 tokio::select! {
256 opt_cmd_tup = self.outq.recv() => {
258 match opt_cmd_tup {
259 Some(cmd) => {
260 self.handle_host_command(cmd).await?;
261 }
262 None => {
263 info!("exiting run loop");
265 return Ok(());
266 }
267 }
268 },
269 res_r = self.ser.read(&mut read_buf) => {
270 let n_bytes_read = res_r?;
271 if n_bytes_read > 0 {
272 update_read_buffer(n_bytes_read,&read_buf,&mut buf);
273 new_data = true;
274 }
275 },
276 _ = interval.tick() => {}
277 }
278 } else {
279 tokio::select! {
282 res_r = self.ser.read(&mut read_buf) => {
283 let n_bytes_read = res_r?;
284 if n_bytes_read > 0 {
285 update_read_buffer(n_bytes_read,&read_buf,&mut buf);
286 new_data = true;
287 }
288 },
289 _ = interval.tick() => {}
290 }
291 }
292
293 if new_data {
295 buf = self.handle_data_from_device(buf).await?;
296 new_data = false;
297 }
298
299 now = chrono::Utc::now();
300
301 if self.version_check_done {
302 if self.allow_requesting_clock_sync
303 & (now.signed_duration_since(self.last_time) > query_dt)
304 {
305 debug!("making clock sample request. qi: {}, now: {}", self.qi, now);
307 self.queries.insert(self.qi, now);
308 let send_buf = [b'P', self.qi];
309 self.write(&send_buf).await?;
310 self.qi = self.qi.wrapping_add(1);
311 self.last_time = now;
312 }
313 } else {
314 if !version_check_started && now >= self.vquery_time {
316 info!("checking firmware version");
317 self.write(b"V?").await?;
318 version_check_started = true;
319 self.vquery_time = now;
320 }
321
322 if now.signed_duration_since(self.vquery_time) > Duration::seconds(1) {
324 version_check_started = false;
325 }
326 if now.signed_duration_since(connect_time) > Duration::seconds(20) {
328 return Err(anyhow::anyhow!("no version response"));
329 }
330 }
331 }
332 }
333
334 async fn _set_top_and_prescaler(&mut self, new_value: TopAndPrescaler) -> Result<()> {
335 use byteorder::{ByteOrder, LittleEndian};
336
337 let mut buf = [0, 0, 0];
338 LittleEndian::write_u16(&mut buf[0..2], new_value.avr_icr1());
339 buf[2] = new_value.prescaler_key();
340
341 self.icr1_and_prescaler = Some(new_value);
342
343 self.write(b"T=").await?;
344 self.write(&buf).await?;
345 Ok(())
346 }
347
348 async fn _handle_returned_timestamp(
349 &mut self,
350 qi: u8,
351 pulsenumber: u32,
352 count: u16,
353 ) -> Result<()> {
354 debug!(
355 "got returned timestamp with qi: {}, pulsenumber: {}, count: {}",
356 qi, pulsenumber, count
357 );
358 let now = chrono::Utc::now();
359 while self.queries.len() > 50 {
360 self.queries.clear();
361 error!("too many outstanding queries");
362 }
363
364 let send_timestamp = match self.queries.remove(&qi) {
365 Some(send_timestamp) => send_timestamp,
366 None => {
367 warn!("could not find original data for query {:?}", qi);
368 return Ok(());
369 }
370 };
371 trace!("this query has send_timestamp: {}", send_timestamp);
372
373 let max_error = now.signed_duration_since(send_timestamp);
374 if max_error > self.max_acceptable_measurement_error {
375 debug!("clock sample took {:?}. Ignoring value.", max_error);
376 return Ok(());
377 }
378
379 trace!("max_error: {:?}", max_error);
380
381 let ino_time_estimate = send_timestamp + (max_error / 2);
382
383 match &self.icr1_and_prescaler {
384 Some(s) => {
385 let frac = count as f64 / s.avr_icr1() as f64;
386 debug_assert!(0.0 <= frac);
387 debug_assert!(frac <= 1.0);
388 let ino_stamp = na::convert(pulsenumber as f64 + frac);
389
390 if let Some(ref tbox_tx) = self.triggerbox_data_tx {
391 let to_save = TriggerClockInfoRow {
393 start_timestamp: send_timestamp,
394 framecount: pulsenumber as i64,
395 tcnt: (frac * 255.0) as u8,
396 stop_timestamp: now,
397 };
398 match tbox_tx.send(to_save).await {
399 Ok(()) => {}
400 Err(e) => {
401 warn!("ignoring {}", e);
402 }
403 }
404 }
405
406 while self.past_data.len() >= 100 {
408 self.past_data.remove(0);
409 }
410
411 self.past_data.push((
412 ino_stamp,
413 datetime_conversion::datetime_to_f64(&ino_time_estimate),
414 ));
415
416 if self.past_data.len() >= 5 {
417 let (gain, offset, residuals) = fit_time_model(&self.past_data)
418 .map_err(|e| anyhow::anyhow!("lstsq err: {}", e))?;
419
420 let n_measurements = self.past_data.len() as u64;
421 let per_point_residual = residuals / n_measurements as f64;
422 debug!(
424 "new: ClockModel{{gain: {}, offset: {}}}, per_point_residual: {}",
425 gain, offset, per_point_residual
426 );
427 (self.on_new_model_cb)(Some(ClockModel {
428 gain,
429 offset,
430 residuals,
431 n_measurements,
432 }));
433 }
434 }
435 None => {
436 warn!("No clock measurements until framerate set.");
437 }
438 }
439 Ok(())
440 }
441
442 fn _handle_version(&mut self, value: u8, _pulsenumber: u32, _count: u16) -> Result<()> {
443 trace!("got returned version with value: {}", value);
444 assert_eq!(value, DEVICE_FIRMWARE_VERSION);
445 self.vquery_time = chrono::Utc::now();
446 self.version_check_done = true;
447 info!("connected to triggerbox firmware version {}", value);
448 Ok(())
449 }
450
451 async fn handle_data_from_device(&mut self, buf: Vec<u8>) -> Result<Vec<u8>> {
452 if buf.len() >= 3 {
453 let mut valid_n_chars = None;
455
456 let packet_type = buf[0] as char;
457 let payload_len = buf[1];
458
459 let min_valid_packet_size = 3 + payload_len as usize; if buf.len() >= min_valid_packet_size {
461 let expected_chksum = buf[2 + payload_len as usize];
462
463 let check_buf = &buf[2..buf.len() - 1];
464 let bytes = check_buf;
465 let actual_chksum = bytes.iter().fold(0, |acc: u8, x| acc.wrapping_add(*x));
466
467 if actual_chksum == expected_chksum {
468 trace!("checksum OK");
469 valid_n_chars = Some(bytes.len() + 3)
470 } else {
471 return Err(anyhow::anyhow!("checksum mismatch"));
472 }
473
474 if (packet_type == 'P') | (packet_type == 'V') {
475 assert!(payload_len == 7);
476 let value = bytes[0];
477
478 use byteorder::{ByteOrder, LittleEndian};
479 let pulsenumber = LittleEndian::read_u32(&bytes[1..5]);
480 let count = LittleEndian::read_u16(&bytes[5..7]);
481
482 match packet_type {
483 'P' => {
484 self._handle_returned_timestamp(value, pulsenumber, count)
485 .await?
486 }
487 'V' => self._handle_version(value, pulsenumber, count)?,
488 _ => unreachable!(),
489 };
490 }
491 }
492
493 if let Some(n_used_chars) = valid_n_chars {
494 return Ok(buf[n_used_chars..].to_vec());
495 }
496 }
497 Ok(buf)
498 }
499}
500
501fn fit_time_model(past_data: &[(f64, f64)]) -> Result<(f64, f64, f64), &'static str> {
502 use na::{OMatrix, OVector, U2};
503
504 let mut a: Vec<f64> = Vec::with_capacity(past_data.len() * 2);
505 let mut b: Vec<f64> = Vec::with_capacity(past_data.len());
506
507 for row in past_data.iter() {
508 a.push(row.0);
509 a.push(1.0);
510 b.push(row.1);
511 }
512 let a = OMatrix::<f64, na::Dyn, U2>::from_row_slice(&a);
513 let b = OVector::<f64, na::Dyn>::from_row_slice(&b);
514
515 let epsilon = 1e-10;
516 let results = lstsq::lstsq(&a, &b, epsilon)?;
517
518 let gain = results.solution[0];
519 let offset = results.solution[1];
520 let residuals = results.residuals;
521
522 Ok((gain, offset, residuals))
523}
524
#[test]
fn test_fit_time_model() {
    const TOLERANCE: f64 = 1e-12;

    // Each case: samples on an exact line, then the expected (gain, offset).
    let cases: [(&[(f64, f64)], f64, f64); 2] = [
        (&[(0.0, 0.0), (1.0, 1.0), (2.0, 2.0), (3.0, 3.0)], 1.0, 0.0),
        (&[(0.0, 12.0), (1.0, 22.0), (2.0, 32.0), (3.0, 42.0)], 10.0, 12.0),
    ];

    for &(samples, want_gain, want_offset) in cases.iter() {
        let (gain, offset, _residuals) = fit_time_model(samples).unwrap();
        assert!((gain - want_gain).abs() < TOLERANCE);
        assert!((offset - want_offset).abs() < TOLERANCE);
    }
}
539
540#[derive(Clone, Debug)]
541pub struct TriggerboxOptions {
542 pub device_path: String,
543 pub query_dt: std::time::Duration,
544 pub assert_device_name: NameType,
545 pub max_acceptable_measurement_error: std::time::Duration,
546 pub sleep_dur: std::time::Duration,
547}
548
549pub async fn run_triggerbox(
550 on_new_model_cb: ClockModelCallback,
551 outq: Receiver<Cmd>,
552 triggerbox_data_tx: Option<Sender<TriggerClockInfoRow>>,
553 opts: TriggerboxOptions,
554) -> Result<()> {
555 let TriggerboxOptions {
556 device_path,
557 query_dt,
558 assert_device_name,
559 max_acceptable_measurement_error,
560 sleep_dur,
561 } = opts;
562
563 let triggerbox = TriggerboxDevice::new(
564 on_new_model_cb,
565 device_path,
566 outq,
567 triggerbox_data_tx,
568 assert_device_name,
569 max_acceptable_measurement_error,
570 sleep_dur,
571 )
572 .await?;
573 triggerbox.run_forever(query_dt).await
574}
575
576fn get_rate(rate_ideal: f64, prescaler: Prescaler) -> (u16, f64) {
577 let xtal = 16e6; let base_clock = xtal / prescaler.as_f64();
579 let new_top_ideal = base_clock / rate_ideal;
580 let new_icr1_f64 = new_top_ideal.round();
581 let new_icr1: u16 = if new_icr1_f64 > 0xFFFF as f64 {
582 0xFFFF
583 } else if new_icr1_f64 < 0.0 {
584 0
585 } else {
586 new_icr1_f64 as u16
587 };
588 let rate_actual = base_clock / new_icr1 as f64;
589 (new_icr1, rate_actual)
590}
591
592pub fn make_trig_fps_cmd(rate_ideal: f64) -> (Cmd, f64) {
598 let (top_8, rate_actual_8) = get_rate(rate_ideal, Prescaler::Scale8);
599 let (top_64, rate_actual_64) = get_rate(rate_ideal, Prescaler::Scale64);
600
601 let error_8 = (rate_ideal - rate_actual_8).abs();
602 let error_64 = (rate_ideal - rate_actual_64).abs();
603
604 let (top, rate_actual, prescaler) = if error_8 < error_64 {
605 (top_8, rate_actual_8, Prescaler::Scale8)
606 } else {
607 (top_64, rate_actual_64, Prescaler::Scale64)
608 };
609
610 (
611 Cmd::TopAndPrescaler(TopAndPrescaler::new_avr(top, prescaler)),
612 rate_actual,
613 )
614}