riscv_target_parser/
extension.rs

1use crate::Error;
2use std::collections::HashSet;
3
/// A RISC-V ISA extension.
///
/// The declaration order of the variants is significant: the derived `Ord` compares variants by
/// position first, so sorting yields the base ISAs (`I`, then `E`), then the single-letter
/// standard extensions, then `Z`, `S`, and `X` extensions. Variants that carry a name are
/// compared by that name within their own group.
#[derive(Clone, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub enum Extension {
    /// `I`: base integer instruction set.
    I,
    /// `E`: reduced base integer instruction set for embedded cores (16 integer registers).
    E,
    /// `M`: integer multiplication and division.
    M,
    /// `A`: atomic instructions.
    A,
    /// `F`: single-precision floating point.
    F,
    /// `D`: double-precision floating point.
    D,
    /// `Q`: quad-precision floating point.
    Q,
    /// `C`: compressed instructions.
    C,
    /// `B`: bit manipulation.
    B,
    /// `P`: packed-SIMD instructions.
    P,
    /// `V`: vector operations.
    V,
    /// `H`: hypervisor.
    H,
    /// A standard `Z…` extension; holds the full name including the leading `z` (e.g. `"zicsr"`).
    Z(String),
    /// A standard `S…` extension; holds the full name including the leading `s` (e.g. `"ssccfg"`).
    S(String),
    /// A vendor `X…` extension; holds the full name including the leading `x`.
    X(String),
}
38
39impl Extension {
40    /// Determines if the extension is a base extension.
41    pub const fn is_base(&self) -> bool {
42        matches!(self, Self::I | Self::E)
43    }
44}
45
46impl std::fmt::Display for Extension {
47    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
48        let repr = match self {
49            Self::I => "i",
50            Self::E => "e",
51            Self::M => "m",
52            Self::A => "a",
53            Self::F => "f",
54            Self::D => "d",
55            Self::Q => "q",
56            Self::C => "c",
57            Self::B => "b",
58            Self::P => "p",
59            Self::V => "v",
60            Self::H => "h",
61            Self::Z(s) | Self::S(s) | Self::X(s) => s,
62        };
63        write!(f, "{repr}")
64    }
65}
66
67impl<'a> TryFrom<&'a str> for Extension {
68    type Error = Error<'a>;
69
70    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
71        match value {
72            "i" => Ok(Extension::I),
73            "e" => Ok(Extension::E),
74            "m" => Ok(Extension::M),
75            "a" => Ok(Extension::A),
76            "f" => Ok(Extension::F),
77            "d" => Ok(Extension::D),
78            "q" => Ok(Extension::Q),
79            "c" => Ok(Extension::C),
80            "b" => Ok(Extension::B),
81            "p" => Ok(Extension::P),
82            "v" => Ok(Extension::V),
83            "h" => Ok(Extension::H),
84            _ => {
85                if value.starts_with('z') {
86                    Ok(Extension::Z(value.to_string()))
87                } else if value.starts_with('s') {
88                    Ok(Extension::S(value.to_string()))
89                } else if value.starts_with('x') {
90                    Ok(Extension::X(value.to_string()))
91                } else {
92                    Err(Self::Error::UnknownExtension(value))
93                }
94            }
95        }
96    }
97}
98
99/// Collection of RISC-V extensions.
100#[derive(Debug, Clone, PartialEq, Eq)]
101pub struct Extensions {
102    extensions: HashSet<Extension>,
103}
104
105impl Extensions {
106    /// Returns a vector with the list of extensions. Extensions are sorted in canonical order.
107    ///
108    /// The canonical order is defined as follows:
109    /// 1. Base ISA (I or E)
110    /// 2. Standard non-base extensions (M, A, F, D, Q, C, B, P, V, H)
111    /// 3. Standard Z-type extensions (e.g., Zicsr)
112    /// 4. Standard S-type extensions (e.g., Ssccfg)
113    /// 5. Vendor X-type extensions (e.g., XSifivecdiscarddlone)
114    ///
115    /// Z, S, and X-type extensions are sorted by their string representation.
116    pub fn extensions(&self) -> Vec<Extension> {
117        let mut res = self.extensions.iter().cloned().collect::<Vec<_>>();
118        res.sort();
119        res
120    }
121
122    /// Returns the base extension (I or E) if present.
123    pub fn base_extension(&self) -> Option<Extension> {
124        if self.extensions.contains(&Extension::I) {
125            Some(Extension::I)
126        } else if self.extensions.contains(&Extension::E) {
127            Some(Extension::E)
128        } else {
129            None
130        }
131    }
132
133    /// Returns `true` if the collection contains the given extension.
134    pub fn contains(&self, extension: &Extension) -> bool {
135        self.extensions.contains(extension)
136    }
137
138    pub fn is_g(&self) -> bool {
139        self.extensions.contains(&Extension::I)
140            && self.extensions.contains(&Extension::M)
141            && self.extensions.contains(&Extension::A)
142            && self.extensions.contains(&Extension::F)
143            && self.extensions.contains(&Extension::D)
144    }
145
146    /// Adds an extension to the collection. Returns `true` if the extension was not present.
147    pub fn insert(&mut self, extension: Extension) -> bool {
148        self.extensions.insert(extension)
149    }
150
151    /// Removes an extension from the collection. Returns `true` if the extension was present.
152    pub fn remove(&mut self, extension: &Extension) -> bool {
153        self.extensions.remove(extension)
154    }
155}
156
157impl<'a> TryFrom<&'a str> for Extensions {
158    type Error = Error<'a>;
159
160    fn try_from(value: &'a str) -> Result<Self, Self::Error> {
161        let mut value = value;
162        let mut extensions = HashSet::new();
163
164        while !value.is_empty() {
165            let extension =
166                if value.starts_with("z") || value.starts_with("s") || value.starts_with("x") {
167                    match value.find('_') {
168                        Some(pos) => {
169                            let (ext, _) = value.split_at(pos);
170                            ext
171                        }
172                        None => value,
173                    }
174                } else {
175                    &value[0..1] // single character extension
176                };
177            value = value.trim_start_matches(extension).trim_start_matches("_");
178
179            match Extension::try_from(extension) {
180                Ok(ext) => {
181                    extensions.insert(ext);
182                }
183                Err(Self::Error::UnknownExtension(ext)) => {
184                    if ext == "g" {
185                        // G is a shorthand for IMAFD
186                        extensions.insert(Extension::I);
187                        extensions.insert(Extension::M);
188                        extensions.insert(Extension::A);
189                        extensions.insert(Extension::F);
190                        extensions.insert(Extension::D);
191                    } else {
192                        return Err(Self::Error::UnknownExtension(ext));
193                    }
194                }
195                _ => unreachable!(),
196            }
197        }
198        Ok(Extensions { extensions })
199    }
200}
201
202impl std::fmt::Display for Extensions {
203    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
204        let mut extensions = String::new();
205        let mut prev_zsx = false;
206        for ext in &self.extensions() {
207            if prev_zsx {
208                extensions.push('_');
209            }
210            extensions.push_str(ext.to_string().as_str());
211            prev_zsx = matches!(ext, Extension::Z(_) | Extension::S(_) | Extension::X(_));
212        }
213        match extensions.strip_prefix("imafd") {
214            Some(extensions) => write!(f, "g{extensions}"),
215            None => match extensions.strip_prefix("iemafd") {
216                Some(extensions) => write!(f, "ge{extensions}"),
217                None => write!(f, "{extensions}"),
218            },
219        }
220    }
221}
222
#[cfg(test)]
mod test {
    // Unit tests for single-extension parsing/printing/ordering and for the `Extensions` set.
    use super::*;

    #[test]
    fn test_extension_try_from() {
        assert_eq!(Extension::try_from("i"), Ok(Extension::I));
        assert_eq!(Extension::try_from("e"), Ok(Extension::E));
        assert_eq!(Extension::try_from("m"), Ok(Extension::M));
        assert_eq!(Extension::try_from("a"), Ok(Extension::A));
        assert_eq!(Extension::try_from("f"), Ok(Extension::F));
        assert_eq!(Extension::try_from("d"), Ok(Extension::D));
        assert_eq!(Extension::try_from("q"), Ok(Extension::Q));
        assert_eq!(Extension::try_from("c"), Ok(Extension::C));
        assert_eq!(Extension::try_from("b"), Ok(Extension::B));
        assert_eq!(Extension::try_from("p"), Ok(Extension::P));
        assert_eq!(Extension::try_from("v"), Ok(Extension::V));
        assert_eq!(Extension::try_from("h"), Ok(Extension::H));
        assert_eq!(
            Extension::try_from("zicsr"),
            Ok(Extension::Z("zicsr".to_string()))
        );
        assert_eq!(
            Extension::try_from("ssccfg"),
            Ok(Extension::S("ssccfg".to_string()))
        );
        assert_eq!(
            Extension::try_from("xsifivecdiscarddlone"),
            Ok(Extension::X("xsifivecdiscarddlone".to_string()))
        );
        assert_eq!(
            Extension::try_from("unknown"),
            Err(Error::UnknownExtension("unknown"))
        );
    }

    #[test]
    fn test_extension_to_string() {
        assert_eq!(Extension::I.to_string(), "i");
        assert_eq!(Extension::E.to_string(), "e");
        assert_eq!(Extension::M.to_string(), "m");
        assert_eq!(Extension::A.to_string(), "a");
        assert_eq!(Extension::F.to_string(), "f");
        assert_eq!(Extension::D.to_string(), "d");
        assert_eq!(Extension::Q.to_string(), "q");
        assert_eq!(Extension::C.to_string(), "c");
        assert_eq!(Extension::B.to_string(), "b");
        assert_eq!(Extension::P.to_string(), "p");
        assert_eq!(Extension::V.to_string(), "v");
        assert_eq!(Extension::H.to_string(), "h");
        assert_eq!(Extension::Z("zicsr".to_string()).to_string(), "zicsr");
        assert_eq!(Extension::S("ssccfg".to_string()).to_string(), "ssccfg");
        assert_eq!(
            Extension::X("xsifivecdiscarddlone".to_string()).to_string(),
            "xsifivecdiscarddlone"
        );
    }

    #[test]
    fn test_extension_is_base() {
        assert!(Extension::I.is_base());
        assert!(Extension::E.is_base());
        assert!(!Extension::M.is_base());
        assert!(!Extension::H.is_base());
        assert!(!Extension::Z("zicsr".to_string()).is_base());
        assert!(!Extension::X("xsifivecdiscarddlone".to_string()).is_base());
    }

    #[test]
    fn test_extension_cmp() {
        let mut extensions = vec![
            Extension::I,
            Extension::E,
            Extension::M,
            Extension::A,
            Extension::F,
            Extension::D,
            Extension::Q,
            Extension::C,
            Extension::B,
            Extension::P,
            Extension::V,
            Extension::H,
            Extension::Z("zicsr".to_string()),
            Extension::S("ssccfg".to_string()),
            Extension::X("xsifivecdiscarddlone".to_string()),
        ];
        extensions.reverse();
        extensions.sort();
        assert_eq!(
            extensions,
            vec![
                Extension::I,
                Extension::E,
                Extension::M,
                Extension::A,
                Extension::F,
                Extension::D,
                Extension::Q,
                Extension::C,
                Extension::B,
                Extension::P,
                Extension::V,
                Extension::H,
                Extension::Z("zicsr".to_string()),
                Extension::S("ssccfg".to_string()),
                Extension::X("xsifivecdiscarddlone".to_string()),
            ]
        );
    }

    #[test]
    fn test_extensions_try_from() {
        let mut try_extensions = Extensions::try_from("");
        assert!(try_extensions.is_ok());
        let mut extensions = try_extensions.unwrap();
        assert!(extensions.extensions().is_empty());
        assert!(extensions.base_extension().is_none());

        try_extensions =
            Extensions::try_from("giemafdqcbpvhxsifivecdiscarddlone_ssccfg_zicsr_zaamo_u");
        assert!(try_extensions.is_err());
        assert_eq!(try_extensions, Err(Error::UnknownExtension("u")));

        try_extensions = Extensions::try_from("geqcbpvhxsifivecdiscarddlone_ssccfg_zicsr_zaamo_");
        assert!(try_extensions.is_ok());
        extensions = try_extensions.unwrap();
        assert_eq!(
            extensions.extensions(),
            vec![
                Extension::I,
                Extension::E,
                Extension::M,
                Extension::A,
                Extension::F,
                Extension::D,
                Extension::Q,
                Extension::C,
                Extension::B,
                Extension::P,
                Extension::V,
                Extension::H,
                Extension::Z("zaamo".to_string()),
                Extension::Z("zicsr".to_string()),
                Extension::S("ssccfg".to_string()),
                Extension::X("xsifivecdiscarddlone".to_string()),
            ]
        );
        assert_eq!(extensions.base_extension(), Some(Extension::I));

        try_extensions =
            Extensions::try_from("iemafdqcbpvhxsifivecdiscarddlone_ssccfg_zicsr_zaamo_");
        assert!(try_extensions.is_ok());
        extensions = try_extensions.unwrap();
        assert_eq!(
            extensions.extensions(),
            vec![
                Extension::I,
                Extension::E,
                Extension::M,
                Extension::A,
                Extension::F,
                Extension::D,
                Extension::Q,
                Extension::C,
                Extension::B,
                Extension::P,
                Extension::V,
                Extension::H,
                Extension::Z("zaamo".to_string()),
                Extension::Z("zicsr".to_string()),
                Extension::S("ssccfg".to_string()),
                Extension::X("xsifivecdiscarddlone".to_string()),
            ]
        );
        assert_eq!(extensions.base_extension(), Some(Extension::I));

        try_extensions =
            Extensions::try_from("emafdqcbpvhxsifivecdiscarddlone_ssccfg_zicsr_zaamo_");
        assert!(try_extensions.is_ok());
        extensions = try_extensions.unwrap();
        assert_eq!(
            extensions.extensions(),
            vec![
                Extension::E,
                Extension::M,
                Extension::A,
                Extension::F,
                Extension::D,
                Extension::Q,
                Extension::C,
                Extension::B,
                Extension::P,
                Extension::V,
                Extension::H,
                Extension::Z("zaamo".to_string()),
                Extension::Z("zicsr".to_string()),
                Extension::S("ssccfg".to_string()),
                Extension::X("xsifivecdiscarddlone".to_string()),
            ]
        );
        assert_eq!(extensions.base_extension(), Some(Extension::E));
    }

    #[test]
    fn test_extensions_insert_remove() {
        let mut extensions = Extensions::try_from("gc").unwrap();

        assert_eq!(extensions.extensions.len(), 6);
        assert!(extensions.contains(&Extension::I));
        assert!(extensions.contains(&Extension::M));
        assert!(extensions.contains(&Extension::A));
        assert!(extensions.contains(&Extension::F));
        assert!(extensions.contains(&Extension::D));
        assert!(extensions.contains(&Extension::C));
        assert!(!extensions.contains(&Extension::E));
        assert!(!extensions.contains(&Extension::Q));
        assert_eq!(extensions.base_extension(), Some(Extension::I));

        assert!(!extensions.insert(Extension::I));
        assert!(!extensions.remove(&Extension::E));
        assert_eq!(extensions.extensions.len(), 6);

        assert!(extensions.insert(Extension::E));
        assert_eq!(extensions.extensions.len(), 7);
        assert!(extensions.contains(&Extension::E));
        assert_eq!(extensions.base_extension(), Some(Extension::I));

        assert!(extensions.remove(&Extension::I));
        assert_eq!(extensions.extensions.len(), 6);
        assert!(!extensions.contains(&Extension::I));
        assert_eq!(extensions.base_extension(), Some(Extension::E));

        assert!(extensions.remove(&Extension::E));
        assert_eq!(extensions.extensions.len(), 5);
        assert!(!extensions.contains(&Extension::E));
        assert_eq!(extensions.base_extension(), None);
    }

    #[test]
    fn test_extensions_is_g() {
        assert!(Extensions::try_from("gc").unwrap().is_g());
        assert!(Extensions::try_from("imafd").unwrap().is_g());
        // Missing D.
        assert!(!Extensions::try_from("imaf").unwrap().is_g());
        // E instead of I.
        assert!(!Extensions::try_from("emafd").unwrap().is_g());
        assert!(!Extensions::try_from("").unwrap().is_g());
    }

    #[test]
    fn test_extensions_to_string() {
        let mut extensions = Extensions::try_from("imafdc").unwrap();
        assert_eq!(extensions.to_string(), "gc");

        extensions.insert(Extension::try_from("ssccfg").unwrap());
        assert_eq!(extensions.to_string(), "gcssccfg");

        extensions.insert(Extension::try_from("zicsr").unwrap());
        assert_eq!(extensions.to_string(), "gczicsr_ssccfg");

        extensions.insert(Extension::try_from("zaamo").unwrap());
        assert_eq!(extensions.to_string(), "gczaamo_zicsr_ssccfg");

        extensions.insert(Extension::try_from("xsifivecdiscarddlone").unwrap());
        assert_eq!(
            extensions.to_string(),
            "gczaamo_zicsr_ssccfg_xsifivecdiscarddlone"
        );

        extensions.insert(Extension::try_from("e").unwrap());
        assert_eq!(
            extensions.to_string(),
            "geczaamo_zicsr_ssccfg_xsifivecdiscarddlone"
        );

        extensions.remove(&Extension::I);
        assert_eq!(
            extensions.to_string(),
            "emafdczaamo_zicsr_ssccfg_xsifivecdiscarddlone"
        );

        extensions.remove(&Extension::E);
        assert_eq!(
            extensions.to_string(),
            "mafdczaamo_zicsr_ssccfg_xsifivecdiscarddlone"
        );

        extensions.insert(Extension::I);
        assert_eq!(
            extensions.to_string(),
            "gczaamo_zicsr_ssccfg_xsifivecdiscarddlone"
        );
    }

    #[test]
    fn test_extensions_round_trip() {
        // Printing a parsed canonical string must reproduce it exactly, and re-parsing the
        // printed form must yield the same set.
        for isa in [
            "gc",
            "gczicsr_ssccfg",
            "geczaamo_zicsr",
            "emafdczaamo_zicsr_ssccfg_xsifivecdiscarddlone",
        ] {
            let parsed = Extensions::try_from(isa).unwrap();
            assert_eq!(parsed.to_string(), isa);
            let printed = parsed.to_string();
            let reparsed = Extensions::try_from(printed.as_str()).unwrap();
            assert_eq!(reparsed, parsed);
        }
    }
}