Skip to main content

agent_sim/
protocol.rs

1use crate::load::LoadSpec;
2use crate::sim::types::{SignalType, SignalValue, SimCanFrame};
3use serde::{Deserialize, Serialize};
4use std::collections::BTreeMap;
5use thiserror::Error;
6use uuid::Uuid;
7
/// Errors raised while interpreting protocol messages.
#[derive(Debug, Error)]
pub enum ProtocolError {
    /// A duration string could not be parsed by [`parse_duration_us`];
    /// holds the trimmed input text.
    #[error("invalid duration: {0}")]
    InvalidDuration(String),
    /// A request was rejected; the payload is appended to the display message.
    #[error("invalid request: {0}")]
    InvalidRequest(String),
    /// JSON (de)serialization failed. `#[from]` lets `?` convert a
    /// `serde_json::Error` into this variant automatically.
    #[error("serialization error: {0}")]
    Serialization(#[from] serde_json::Error),
}
17
/// A single client request: a correlation id plus the action to perform.
// NOTE(review): callers appear to echo `id` into the matching `Response`
// (the tests do `Response::ok(request.id, ..)`) — confirm in the server loop.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Request {
    pub id: Uuid,
    pub action: RequestAction,
}
23
/// Selects which subsystem a request addresses.
///
/// Serialized adjacently tagged:
/// `{"target": "<snake_case variant>", "payload": <inner action>}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "target", content = "payload", rename_all = "snake_case")]
pub enum RequestAction {
    Instance(InstanceAction),
    Worker(WorkerAction),
    Env(EnvAction),
}
31
/// Actions addressed to a single simulation instance.
///
/// Serialized adjacently tagged as
/// `{"type": "<snake_case variant>", "payload": {..fields..}}`; unit variants
/// (e.g. `ping`) carry only the `type` tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "payload", rename_all = "snake_case")]
pub enum InstanceAction {
    Ping,
    Load {
        load_spec: LoadSpec,
    },
    Info,
    Signals,
    Reset,
    Get {
        selectors: Vec<String>,
    },
    Sample {
        selectors: Vec<String>,
    },
    Set {
        /// Name -> value pairs, both as text (the tests use
        /// `"hvac.target_temp" -> "21.5"`); no conversion happens in this module.
        writes: BTreeMap<String, String>,
    },
    TimeStart,
    TimePause,
    TimeStep {
        /// Duration text such as `"250ms"`.
        // NOTE(review): presumably parsed with `parse_duration_us` by the handler — confirm.
        duration: String,
    },
    TimeSpeed {
        // NOTE(review): the meaning of `None` is defined by the handler, not visible here.
        multiplier: Option<f64>,
    },
    TimeStatus,
    CanBuses,
    CanAttach {
        bus_name: String,
        vcan_iface: String,
    },
    CanDetach {
        bus_name: String,
    },
    CanLoadDbc {
        bus_name: String,
        path: String,
    },
    SharedList,
    SharedAttach {
        channel_name: String,
        path: String,
        writer: bool,
        writer_session: String,
    },
    SharedGet {
        channel_name: String,
    },
    CanSend {
        bus_name: String,
        arb_id: u32,
        // NOTE(review): payload encoding assumed to be hex from the field name — confirm.
        data_hex: String,
        flags: Option<u8>,
    },
    InstanceStatus,
    InstanceList,
    Close,
}
92
/// Actions addressed to a worker process.
///
/// Serialized adjacently tagged as `{"type": "<snake_case variant>", "payload": {..}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "payload", rename_all = "snake_case")]
pub enum WorkerAction {
    CanBuses,
    CanAttach {
        bus_name: String,
        vcan_iface: String,
    },
    CanDiscardPendingRx,
    Step,
}
104
/// Actions addressed to a named environment.
///
/// Every variant carries an `env` field identifying the target environment.
/// Serialized adjacently tagged as `{"type": "<snake_case variant>", "payload": {..}}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", content = "payload", rename_all = "snake_case")]
pub enum EnvAction {
    Status {
        env: String,
    },
    Reset {
        env: String,
    },
    TimeStart {
        env: String,
    },
    TimePause {
        env: String,
    },
    TimeStep {
        env: String,
        // NOTE(review): duration text like "250ms"; presumably parsed with
        // `parse_duration_us` by the handler — confirm.
        duration: String,
    },
    TimeSpeed {
        env: String,
        multiplier: Option<f64>,
    },
    TimeStatus {
        env: String,
    },
    CanBuses {
        env: String,
    },
    CanLoadDbc {
        env: String,
        bus_name: String,
        path: String,
    },
    CanSend {
        env: String,
        bus_name: String,
        arb_id: u32,
        data_hex: String,
        flags: Option<u8>,
    },
    CanInspect {
        env: String,
        bus_name: String,
    },
    CanScheduleAdd {
        env: String,
        bus_name: String,
        // NOTE(review): `None` presumably lets the server pick an id — confirm.
        job_id: Option<String>,
        arb_id: u32,
        data_hex: String,
        // NOTE(review): period text; `CanScheduleData` reports it as `every_ticks`.
        every: String,
        flags: Option<u8>,
    },
    CanScheduleUpdate {
        env: String,
        job_id: String,
        arb_id: u32,
        data_hex: String,
        every: String,
        flags: Option<u8>,
    },
    CanScheduleRemove {
        env: String,
        job_id: String,
    },
    CanScheduleStop {
        env: String,
        job_id: String,
    },
    CanScheduleStart {
        env: String,
        job_id: String,
    },
    CanScheduleList {
        env: String,
        // NOTE(review): `None` presumably lists jobs on all buses — confirm.
        bus_name: Option<String>,
    },
    Close {
        env: String,
    },
}
187
/// Reply to a [`Request`], correlated by `id`.
///
/// Build with `Response::ok` (sets `data`) or `Response::err` (sets `error`).
/// `data` and `error` are omitted from the serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Response {
    pub id: Uuid,
    pub success: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub data: Option<ResponseData>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error: Option<String>,
}
197
198impl Response {
199    pub fn ok(id: Uuid, data: ResponseData) -> Self {
200        Self {
201            id,
202            success: true,
203            data: Some(data),
204            error: None,
205        }
206    }
207
208    pub fn err(id: Uuid, message: impl Into<String>) -> Self {
209        Self {
210            id,
211            success: false,
212            data: None,
213            error: Some(message.into()),
214        }
215    }
216}
217
/// Payload of a successful [`Response`].
///
/// Serialized adjacently tagged as `{"kind": "<snake_case variant>", "value": {..}}`;
/// the unit variant `Ack` carries only the `kind` tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "kind", content = "value", rename_all = "snake_case")]
pub enum ResponseData {
    Ack,
    Loaded {
        libpath: String,
        signal_count: usize,
    },
    ProjectInfo {
        libpath: String,
        tick_duration_us: u32,
        signal_count: usize,
    },
    Signals {
        signals: Vec<SignalData>,
    },
    SignalValues {
        values: Vec<SignalValueData>,
    },
    SignalSample {
        tick: u64,
        time_us: u64,
        values: Vec<SignalValueData>,
    },
    SetResult {
        writes_applied: usize,
    },
    TimeStatus {
        state: TimeStateData,
        elapsed_ticks: u64,
        elapsed_time_us: u64,
        speed: f64,
    },
    TimeAdvanced {
        // `_us` fields use the same microsecond unit that `parse_duration_us` returns.
        requested_us: u64,
        advanced_ticks: u64,
        advanced_us: u64,
    },
    Speed {
        speed: f64,
    },
    CanBuses {
        buses: Vec<CanBusData>,
    },
    CanSend {
        bus: String,
        arb_id: u32,
        len: u8,
    },
    CanInspect {
        bus: String,
        frames: Vec<CanFrameData>,
    },
    CanSchedules {
        schedules: Vec<CanScheduleData>,
    },
    DbcLoaded {
        bus: String,
        signal_count: usize,
    },
    SharedChannels {
        channels: Vec<SharedChannelData>,
    },
    SharedValues {
        channel: String,
        slots: Vec<SharedSlotValueData>,
    },
    WatchSamples {
        samples: Vec<WatchSampleData>,
    },
    RecipeResult {
        recipe: String,
        dry_run: bool,
        steps_executed: usize,
        steps: Vec<RecipeStepResultData>,
    },
    EnvStatus {
        env: String,
        running: bool,
        instance_count: usize,
        tick_duration_us: u32,
    },
    InstanceStatus {
        instance: String,
        socket_path: String,
        running: bool,
        env: Option<String>,
    },
    InstanceList {
        instances: Vec<InstanceInfoData>,
    },
}
310
/// Metadata for one signal: numeric id, name, type and optional units.
/// Listed in `ResponseData::Signals`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SignalData {
    pub id: u32,
    pub name: String,
    pub signal_type: SignalType,
    pub units: Option<String>,
}
318
/// A signal's metadata paired with a value.
/// Used by `ResponseData::SignalValues` and `ResponseData::SignalSample`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SignalValueData {
    pub id: u32,
    pub name: String,
    pub signal_type: SignalType,
    pub value: SignalValue,
    pub units: Option<String>,
}
327
/// One value of a named signal stamped with tick and time (microseconds).
/// Listed in `ResponseData::WatchSamples`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WatchSampleData {
    pub tick: u64,
    pub time_us: u64,
    pub signal: String,
    pub value: SignalValue,
}
335
/// Whether simulated time is paused or running.
/// Serialized as the strings `"paused"` / `"running"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum TimeStateData {
    Paused,
    Running,
}
342
/// Summary of one instance, as listed in `ResponseData::InstanceList`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InstanceInfoData {
    pub name: String,
    pub socket_path: String,
    pub running: bool,
    /// Environment the instance belongs to, if any.
    pub env: Option<String>,
}
350
/// Description of one CAN bus, as listed in `ResponseData::CanBuses`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CanBusData {
    pub id: u32,
    pub name: String,
    pub bitrate: u32,
    // NOTE(review): presumably the CAN FD data-phase bitrate — confirm.
    pub bitrate_data: u32,
    pub fd_capable: bool,
    /// Interface the bus is attached to, if any (cf. `vcan_iface` in `can_attach`).
    pub attached_iface: Option<String>,
}
360
/// A batch of wire-format CAN frames tagged with the bus they belong to.
// NOTE(review): not referenced elsewhere in this file; its producer/consumer lives elsewhere.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CanBusFramesData {
    pub bus_name: String,
    pub frames: Vec<CanFrameWireData>,
}
366
/// Wire representation of a `SimCanFrame` with the payload as raw bytes.
///
/// Converting back into `SimCanFrame` (via `TryFrom`) requires
/// `data.len() <= 64` and `usize::from(len) == data.len()`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CanFrameWireData {
    pub arb_id: u32,
    pub len: u8,
    pub flags: u8,
    pub data: Vec<u8>,
}
374
/// CAN frame as reported in `ResponseData::CanInspect`, payload as text.
// NOTE(review): `data_hex` encoding inferred from its name — confirm with the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CanFrameData {
    pub arb_id: u32,
    pub len: u8,
    pub flags: u8,
    pub data_hex: String,
}
382
383impl From<SimCanFrame> for CanFrameWireData {
384    fn from(value: SimCanFrame) -> Self {
385        Self {
386            arb_id: value.arb_id,
387            len: value.len,
388            flags: value.flags,
389            data: value.payload().to_vec(),
390        }
391    }
392}
393
394impl From<&SimCanFrame> for CanFrameWireData {
395    fn from(value: &SimCanFrame) -> Self {
396        Self {
397            arb_id: value.arb_id,
398            len: value.len,
399            flags: value.flags,
400            data: value.payload().to_vec(),
401        }
402    }
403}
404
405impl TryFrom<CanFrameWireData> for SimCanFrame {
406    type Error = String;
407
408    fn try_from(value: CanFrameWireData) -> Result<Self, Self::Error> {
409        if value.data.len() > 64 {
410            return Err(format!(
411                "CAN frame payload exceeds 64 bytes ({} bytes provided)",
412                value.data.len()
413            ));
414        }
415        let mut data = [0_u8; 64];
416        data[..value.data.len()].copy_from_slice(&value.data);
417        let len = usize::from(value.len);
418        if len != value.data.len() {
419            return Err(format!(
420                "CAN frame length {} does not match payload size {}",
421                value.len,
422                value.data.len()
423            ));
424        }
425        Ok(SimCanFrame {
426            arb_id: value.arb_id,
427            len: value.len,
428            flags: value.flags,
429            data,
430        })
431    }
432}
433
/// One scheduled CAN send job, as listed in `ResponseData::CanSchedules`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CanScheduleData {
    pub job_id: String,
    pub bus: String,
    pub arb_id: u32,
    pub data_hex: String,
    pub flags: u8,
    // NOTE(review): presumably the period and next firing tick derived from the
    // `every` text of `can_schedule_add` — confirm in the scheduler.
    pub every_ticks: u64,
    pub next_due_tick: u64,
    pub enabled: bool,
}
445
/// A shared channel, as listed in `ResponseData::SharedChannels`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedChannelData {
    pub id: u32,
    pub name: String,
    pub slot_count: u32,
}
452
/// Typed value of one slot in a shared channel.
/// Returned in `ResponseData::SharedValues`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SharedSlotValueData {
    pub slot_id: u32,
    pub signal_type: SignalType,
    pub value: SignalValue,
}
459
/// Kind of recipe step reported in `RecipeStepResultData`.
/// Serialized in snake_case, e.g. `"for_iteration"`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum RecipeStepKindData {
    Set,
    Step,
    Print,
    Speed,
    Reset,
    Sleep,
    Assert,
    ForIteration,
}
472
/// Result entry for one recipe step, listed in `ResponseData::RecipeResult`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecipeStepResultData {
    pub kind: RecipeStepKindData,
    /// Instance the step targeted, if any.
    pub instance: Option<String>,
    /// Free-form description of the step's outcome.
    pub detail: String,
}
479
480pub fn parse_duration_us(input: &str) -> Result<u64, ProtocolError> {
481    let trimmed = input.trim();
482    let (value_part, unit) = if let Some(v) = trimmed.strip_suffix("ms") {
483        (v.trim(), "ms")
484    } else if let Some(v) = trimmed.strip_suffix("us") {
485        (v.trim(), "us")
486    } else if let Some(v) = trimmed.strip_suffix('s') {
487        (v.trim(), "s")
488    } else {
489        return Err(ProtocolError::InvalidDuration(trimmed.to_string()));
490    };
491
492    let value: f64 = value_part
493        .parse()
494        .map_err(|_| ProtocolError::InvalidDuration(trimmed.to_string()))?;
495    if !value.is_finite() || value < 0.0 {
496        return Err(ProtocolError::InvalidDuration(trimmed.to_string()));
497    }
498
499    let us = match unit {
500        "s" => value * 1_000_000.0,
501        "ms" => value * 1_000.0,
502        "us" => value,
503        _ => unreachable!(),
504    };
505
506    if us > u64::MAX as f64 {
507        return Err(ProtocolError::InvalidDuration(trimmed.to_string()));
508    }
509    Ok(us.floor() as u64)
510}
511
#[cfg(test)]
mod tests {
    use super::*;

    /// A request and a success response both survive a JSON round trip.
    #[test]
    fn request_response_serde_roundtrip() {
        let request = Request {
            id: Uuid::new_v4(),
            action: RequestAction::Instance(InstanceAction::Set {
                writes: BTreeMap::from([
                    ("hvac.power".to_string(), "true".to_string()),
                    ("hvac.target_temp".to_string(), "21.5".to_string()),
                ]),
            }),
        };
        let encoded_request =
            serde_json::to_string(&request).expect("request should serialize to json");
        let decoded_request: Request =
            serde_json::from_str(&encoded_request).expect("request should deserialize from json");
        match decoded_request.action {
            RequestAction::Instance(InstanceAction::Set { writes }) => {
                assert_eq!(writes.len(), 2);
            }
            other => panic!("expected set action, got {other:?}"),
        }

        let response = Response::ok(request.id, ResponseData::SetResult { writes_applied: 2 });
        let encoded_response =
            serde_json::to_string(&response).expect("response should serialize to json");
        let decoded_response: Response =
            serde_json::from_str(&encoded_response).expect("response should deserialize from json");
        assert!(decoded_response.success);
        assert!(decoded_response.error.is_none());
    }

    /// Error responses skip the `data` key and keep the message.
    #[test]
    fn error_response_omits_data_field() {
        let response = Response::err(Uuid::nil(), "boom");
        let encoded =
            serde_json::to_string(&response).expect("error response should serialize to json");
        assert!(!encoded.contains("\"data\""));
        let decoded: Response =
            serde_json::from_str(&encoded).expect("error response should deserialize from json");
        assert!(!decoded.success);
        assert!(decoded.data.is_none());
        assert_eq!(decoded.error.as_deref(), Some("boom"));
    }

    /// Pins the adjacently tagged wire format for nested actions.
    #[test]
    fn env_action_uses_adjacent_tags() {
        let request = Request {
            id: Uuid::nil(),
            action: RequestAction::Env(EnvAction::CanScheduleList {
                env: "bench".to_string(),
                bus_name: None,
            }),
        };
        let encoded = serde_json::to_string(&request).expect("request should serialize to json");
        assert!(encoded.contains("\"target\":\"env\""));
        assert!(encoded.contains("\"type\":\"can_schedule_list\""));
        let decoded: Request =
            serde_json::from_str(&encoded).expect("request should deserialize from json");
        match decoded.action {
            RequestAction::Env(EnvAction::CanScheduleList { env, bus_name }) => {
                assert_eq!(env, "bench");
                assert!(bus_name.is_none());
            }
            other => panic!("expected can_schedule_list action, got {other:?}"),
        }
    }

    #[test]
    fn duration_parser_handles_units() {
        assert_eq!(parse_duration_us("1s").expect("1s should parse"), 1_000_000);
        assert_eq!(
            parse_duration_us("250ms").expect("250ms should parse"),
            250_000
        );
        assert_eq!(parse_duration_us("500us").expect("500us should parse"), 500);
        assert_eq!(
            parse_duration_us("0.5s").expect("0.5s should parse"),
            500_000
        );
        // Surrounding and inner whitespace is tolerated.
        assert_eq!(
            parse_duration_us(" 2 ms ").expect("padded input should parse"),
            2_000
        );
        // Exponent notation is part of the accepted number syntax.
        assert_eq!(
            parse_duration_us("1.5e3us").expect("exponent notation should parse"),
            1_500
        );
    }

    #[test]
    fn duration_parser_rejects_invalid_values() {
        assert!(parse_duration_us("abc").is_err());
        assert!(parse_duration_us("-1s").is_err());
        assert!(parse_duration_us("1m").is_err());
        // Missing number or missing unit.
        assert!(parse_duration_us("").is_err());
        assert!(parse_duration_us("10").is_err());
        assert!(parse_duration_us("ms").is_err());
        // Non-finite values.
        assert!(parse_duration_us("infs").is_err());
        assert!(parse_duration_us("NaNs").is_err());
    }
}