// seedlink_rs_protocol/command.rs

1use crate::error::{Result, SeedlinkError};
2use crate::info::InfoLevel;
3use crate::sequence::SequenceNumber;
4use crate::version::ProtocolVersion;
5
6#[derive(Clone, Debug, PartialEq, Eq)]
7pub enum Command {
8    // Both v3 + v4
9    Hello,
10    Station {
11        station: String,
12        network: String,
13    },
14    Select {
15        pattern: String,
16    },
17    Data {
18        sequence: Option<SequenceNumber>,
19        start: Option<String>,
20        end: Option<String>,
21    },
22    End,
23    Bye,
24    Info {
25        level: InfoLevel,
26    },
27
28    // v3 only
29    Batch,
30    Fetch {
31        sequence: Option<SequenceNumber>,
32    },
33    Time {
34        start: String,
35        end: Option<String>,
36    },
37    Cat,
38
39    // v4 only
40    SlProto {
41        version: String,
42    },
43    Auth {
44        value: String,
45    },
46    UserAgent {
47        description: String,
48    },
49    EndFetch,
50}
51
52impl Command {
53    /// Parse a command from a text line (version-agnostic).
54    ///
55    /// The line should NOT include the trailing `\r\n`.
56    pub fn parse(line: &str) -> Result<Self> {
57        let line = line.trim_end_matches('\n').trim_end_matches('\r');
58        let mut parts = line.split_whitespace();
59        let keyword = parts
60            .next()
61            .ok_or_else(|| SeedlinkError::InvalidCommand("empty command".into()))?;
62
63        match keyword.to_uppercase().as_str() {
64            "HELLO" => {
65                reject_extra_args(&mut parts, "HELLO")?;
66                Ok(Self::Hello)
67            }
68            "STATION" => {
69                let first = parts.next().ok_or_else(|| {
70                    SeedlinkError::InvalidCommand("STATION requires arguments".into())
71                })?;
72                // v4 uses "NET_STA" combined, v3 uses "STA NET" separate
73                if let Some(net) = parts.next() {
74                    reject_extra_args(&mut parts, "STATION")?;
75                    Ok(Self::Station {
76                        station: first.to_owned(),
77                        network: net.to_owned(),
78                    })
79                } else {
80                    // v4 combined format: NET_STA
81                    if let Some((net, sta)) = first.split_once('_') {
82                        Ok(Self::Station {
83                            station: sta.to_owned(),
84                            network: net.to_owned(),
85                        })
86                    } else {
87                        Err(SeedlinkError::InvalidCommand(format!(
88                            "STATION: expected 'STA NET' or 'NET_STA', got {first:?}"
89                        )))
90                    }
91                }
92            }
93            "SELECT" => {
94                let pattern = parts.next().ok_or_else(|| {
95                    SeedlinkError::InvalidCommand("SELECT requires a pattern".into())
96                })?;
97                reject_extra_args(&mut parts, "SELECT")?;
98                Ok(Self::Select {
99                    pattern: pattern.to_owned(),
100                })
101            }
102            "DATA" => {
103                let seq_str = parts.next();
104                let start = parts.next().map(|s| s.to_owned());
105                let end = parts.next().map(|s| s.to_owned());
106                let sequence = seq_str.map(parse_sequence).transpose()?;
107                Ok(Self::Data {
108                    sequence,
109                    start,
110                    end,
111                })
112            }
113            "END" => {
114                reject_extra_args(&mut parts, "END")?;
115                Ok(Self::End)
116            }
117            "BYE" => {
118                reject_extra_args(&mut parts, "BYE")?;
119                Ok(Self::Bye)
120            }
121            "INFO" => {
122                let level_str = parts
123                    .next()
124                    .ok_or_else(|| SeedlinkError::InvalidCommand("INFO requires a level".into()))?;
125                reject_extra_args(&mut parts, "INFO")?;
126                let level = InfoLevel::parse(level_str)?;
127                Ok(Self::Info { level })
128            }
129            "BATCH" => {
130                reject_extra_args(&mut parts, "BATCH")?;
131                Ok(Self::Batch)
132            }
133            "FETCH" => {
134                let seq_str = parts.next();
135                let sequence = seq_str.map(parse_sequence).transpose()?;
136                Ok(Self::Fetch { sequence })
137            }
138            "TIME" => {
139                let start = parts
140                    .next()
141                    .ok_or_else(|| SeedlinkError::InvalidCommand("TIME requires start".into()))?
142                    .to_owned();
143                let end = parts.next().map(|s| s.to_owned());
144                Ok(Self::Time { start, end })
145            }
146            "CAT" => {
147                reject_extra_args(&mut parts, "CAT")?;
148                Ok(Self::Cat)
149            }
150            "SLPROTO" => {
151                let version = parts
152                    .next()
153                    .ok_or_else(|| {
154                        SeedlinkError::InvalidCommand("SLPROTO requires version".into())
155                    })?
156                    .to_owned();
157                reject_extra_args(&mut parts, "SLPROTO")?;
158                Ok(Self::SlProto { version })
159            }
160            "AUTH" => {
161                // AUTH value may contain spaces (e.g. "AUTH USERPASS user pass")
162                let rest: Vec<&str> = parts.collect();
163                if rest.is_empty() {
164                    return Err(SeedlinkError::InvalidCommand(
165                        "AUTH requires a value".into(),
166                    ));
167                }
168                Ok(Self::Auth {
169                    value: rest.join(" "),
170                })
171            }
172            "USERAGENT" => {
173                let rest: Vec<&str> = parts.collect();
174                if rest.is_empty() {
175                    return Err(SeedlinkError::InvalidCommand(
176                        "USERAGENT requires a description".into(),
177                    ));
178                }
179                Ok(Self::UserAgent {
180                    description: rest.join(" "),
181                })
182            }
183            "ENDFETCH" => {
184                reject_extra_args(&mut parts, "ENDFETCH")?;
185                Ok(Self::EndFetch)
186            }
187            _ => Err(SeedlinkError::InvalidCommand(format!(
188                "unknown command: {keyword:?}"
189            ))),
190        }
191    }
192
193    /// Serialize to wire bytes for the given protocol version.
194    ///
195    /// Returns `Err(VersionMismatch)` if the command is not valid for the version.
196    pub fn to_bytes(&self, version: ProtocolVersion) -> Result<Vec<u8>> {
197        if !self.is_valid_for(version) {
198            return Err(SeedlinkError::VersionMismatch {
199                command: self.command_name(),
200                version,
201            });
202        }
203        let line = self.format_line(version);
204        Ok(format!("{line}\r\n").into_bytes())
205    }
206
207    /// Check if this command is valid for the given protocol version.
208    pub fn is_valid_for(&self, version: ProtocolVersion) -> bool {
209        match self {
210            Self::Hello
211            | Self::Station { .. }
212            | Self::Select { .. }
213            | Self::Data { .. }
214            | Self::End
215            | Self::Bye
216            | Self::Info { .. } => true,
217            Self::Batch | Self::Fetch { .. } | Self::Time { .. } | Self::Cat => {
218                version == ProtocolVersion::V3
219            }
220            Self::SlProto { .. } | Self::Auth { .. } | Self::UserAgent { .. } | Self::EndFetch => {
221                version == ProtocolVersion::V4
222            }
223        }
224    }
225
226    fn command_name(&self) -> &'static str {
227        match self {
228            Self::Hello => "HELLO",
229            Self::Station { .. } => "STATION",
230            Self::Select { .. } => "SELECT",
231            Self::Data { .. } => "DATA",
232            Self::End => "END",
233            Self::Bye => "BYE",
234            Self::Info { .. } => "INFO",
235            Self::Batch => "BATCH",
236            Self::Fetch { .. } => "FETCH",
237            Self::Time { .. } => "TIME",
238            Self::Cat => "CAT",
239            Self::SlProto { .. } => "SLPROTO",
240            Self::Auth { .. } => "AUTH",
241            Self::UserAgent { .. } => "USERAGENT",
242            Self::EndFetch => "ENDFETCH",
243        }
244    }
245
246    fn format_line(&self, version: ProtocolVersion) -> String {
247        match self {
248            Self::Hello => "HELLO".into(),
249            Self::Station { station, network } => match version {
250                ProtocolVersion::V3 => format!("STATION {station} {network}"),
251                ProtocolVersion::V4 => format!("STATION {network}_{station}"),
252            },
253            Self::Select { pattern } => format!("SELECT {pattern}"),
254            Self::Data {
255                sequence,
256                start,
257                end,
258            } => {
259                let mut s = "DATA".to_owned();
260                if let Some(seq) = sequence {
261                    s.push(' ');
262                    s.push_str(&format_sequence(*seq, version));
263                }
264                if let Some(start_time) = start {
265                    s.push(' ');
266                    s.push_str(start_time);
267                }
268                if let Some(end_time) = end {
269                    s.push(' ');
270                    s.push_str(end_time);
271                }
272                s
273            }
274            Self::End => "END".into(),
275            Self::Bye => "BYE".into(),
276            Self::Info { level } => format!("INFO {}", level.as_str()),
277            Self::Batch => "BATCH".into(),
278            Self::Fetch { sequence } => match sequence {
279                Some(seq) => format!("FETCH {}", format_sequence(*seq, version)),
280                None => "FETCH".into(),
281            },
282            Self::Time { start, end } => match end {
283                Some(e) => format!("TIME {start} {e}"),
284                None => format!("TIME {start}"),
285            },
286            Self::Cat => "CAT".into(),
287            Self::SlProto { version: v } => format!("SLPROTO {v}"),
288            Self::Auth { value } => format!("AUTH {value}"),
289            Self::UserAgent { description } => format!("USERAGENT {description}"),
290            Self::EndFetch => "ENDFETCH".into(),
291        }
292    }
293}
294
295/// Parse a sequence number from either hex (v3) or decimal (v4) format.
296fn parse_sequence(s: &str) -> Result<SequenceNumber> {
297    // Try v3 hex first (exactly 6 hex chars), then fall back to decimal
298    if s.len() == 6 && s.chars().all(|c| c.is_ascii_hexdigit()) {
299        SequenceNumber::from_v3_hex(s)
300    } else {
301        SequenceNumber::from_v4_decimal(s)
302    }
303}
304
305/// Format a sequence number for the given protocol version.
306fn format_sequence(seq: SequenceNumber, version: ProtocolVersion) -> String {
307    match version {
308        ProtocolVersion::V3 => seq.to_v3_hex(),
309        ProtocolVersion::V4 => seq.to_v4_decimal(),
310    }
311}
312
313fn reject_extra_args(parts: &mut std::str::SplitWhitespace<'_>, command: &str) -> Result<()> {
314    if parts.next().is_some() {
315        Err(SeedlinkError::InvalidCommand(format!(
316            "{command}: unexpected extra arguments"
317        )))
318    } else {
319        Ok(())
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `Command::Station` from station and network codes.
    fn station(station: &str, network: &str) -> Command {
        Command::Station {
            station: station.into(),
            network: network.into(),
        }
    }

    /// Build a `Command::Data` that carries at most a sequence number.
    fn data(sequence: Option<SequenceNumber>) -> Command {
        Command::Data {
            sequence,
            start: None,
            end: None,
        }
    }

    /// Serialize each command for `version`, parse the bytes back, and
    /// require the original command.
    fn assert_roundtrip(version: ProtocolVersion, commands: Vec<Command>) {
        for cmd in commands {
            let bytes = cmd.to_bytes(version).unwrap();
            let line = std::str::from_utf8(&bytes).unwrap();
            let parsed = Command::parse(line).unwrap();
            assert_eq!(parsed, cmd, "roundtrip failed for {cmd:?}");
        }
    }

    #[test]
    fn parse_accepts_valid_lines() {
        let cases = vec![
            // Keywords are case-insensitive; a trailing CRLF is tolerated.
            ("HELLO", Command::Hello),
            ("hello", Command::Hello),
            ("HELLO\r\n", Command::Hello),
            // v3 "STA NET" and v4 "NET_STA" station forms.
            ("STATION ANMO IU", station("ANMO", "IU")),
            ("STATION IU_ANMO", station("ANMO", "IU")),
            (
                "SELECT ??.BHZ",
                Command::Select {
                    pattern: "??.BHZ".into(),
                },
            ),
            // DATA with no sequence, v3 hex, and v4 decimal.
            ("DATA", data(None)),
            ("DATA 00001A", data(Some(SequenceNumber::new(26)))),
            ("DATA 26", data(Some(SequenceNumber::new(26)))),
            ("END", Command::End),
            ("BYE", Command::Bye),
            (
                "INFO ID",
                Command::Info {
                    level: InfoLevel::Id,
                },
            ),
            ("BATCH", Command::Batch),
            ("FETCH", Command::Fetch { sequence: None }),
            (
                "FETCH 00004F",
                Command::Fetch {
                    sequence: Some(SequenceNumber::new(0x4F)),
                },
            ),
            (
                "TIME 2024,1,15,0,0,0",
                Command::Time {
                    start: "2024,1,15,0,0,0".into(),
                    end: None,
                },
            ),
            (
                "TIME 2024,1,15,0,0,0 2024,1,16,0,0,0",
                Command::Time {
                    start: "2024,1,15,0,0,0".into(),
                    end: Some("2024,1,16,0,0,0".into()),
                },
            ),
            ("CAT", Command::Cat),
            (
                "SLPROTO 4.0",
                Command::SlProto {
                    version: "4.0".into(),
                },
            ),
            // AUTH keeps every word after the keyword.
            (
                "AUTH USERPASS user pass",
                Command::Auth {
                    value: "USERPASS user pass".into(),
                },
            ),
            (
                "USERAGENT seedlink-rs/0.1",
                Command::UserAgent {
                    description: "seedlink-rs/0.1".into(),
                },
            ),
            ("ENDFETCH", Command::EndFetch),
        ];
        for (line, expected) in cases {
            assert_eq!(Command::parse(line).unwrap(), expected, "parsing {line:?}");
        }
    }

    #[test]
    fn parse_rejects_invalid_lines() {
        for line in ["", "FOOBAR"] {
            assert!(Command::parse(line).is_err(), "{line:?} should not parse");
        }
    }

    #[test]
    fn to_bytes_wire_format() {
        let hello = Command::Hello;
        let anmo = station("ANMO", "IU");
        let seq26 = data(Some(SequenceNumber::new(26)));
        let cases: [(&Command, ProtocolVersion, &[u8]); 5] = [
            (&hello, ProtocolVersion::V3, b"HELLO\r\n"),
            (&anmo, ProtocolVersion::V3, b"STATION ANMO IU\r\n"),
            (&anmo, ProtocolVersion::V4, b"STATION IU_ANMO\r\n"),
            (&seq26, ProtocolVersion::V3, b"DATA 00001A\r\n"),
            (&seq26, ProtocolVersion::V4, b"DATA 26\r\n"),
        ];
        for (cmd, version, expected) in cases {
            assert_eq!(cmd.to_bytes(version).unwrap(), expected, "serializing {cmd:?}");
        }
    }

    #[test]
    fn version_gating() {
        // Serializing a command for the wrong version fails.
        assert!(Command::Batch.to_bytes(ProtocolVersion::V4).is_err());
        let slproto = Command::SlProto {
            version: "4.0".into(),
        };
        assert!(slproto.to_bytes(ProtocolVersion::V3).is_err());

        // Shared, v3-only and v4-only commands.
        for version in [ProtocolVersion::V3, ProtocolVersion::V4] {
            assert!(Command::Hello.is_valid_for(version));
        }
        assert!(Command::Batch.is_valid_for(ProtocolVersion::V3));
        assert!(!Command::Batch.is_valid_for(ProtocolVersion::V4));
        assert!(!Command::EndFetch.is_valid_for(ProtocolVersion::V3));
        assert!(Command::EndFetch.is_valid_for(ProtocolVersion::V4));
    }

    #[test]
    fn roundtrip_v3() {
        assert_roundtrip(
            ProtocolVersion::V3,
            vec![
                Command::Hello,
                station("ANMO", "IU"),
                Command::Select {
                    pattern: "??.BHZ".into(),
                },
                data(Some(SequenceNumber::new(0x1A))),
                Command::End,
                Command::Bye,
                Command::Info {
                    level: InfoLevel::Id,
                },
                Command::Batch,
                Command::Cat,
            ],
        );
    }

    #[test]
    fn roundtrip_v4() {
        assert_roundtrip(
            ProtocolVersion::V4,
            vec![
                Command::Hello,
                station("ANMO", "IU"),
                data(Some(SequenceNumber::new(26))),
                Command::End,
                Command::Bye,
                Command::SlProto {
                    version: "4.0".into(),
                },
                Command::EndFetch,
            ],
        );
    }
}