device_driver_generation/
lib.rs

1#![doc = include_str!(concat!("../", env!("CARGO_PKG_README")))]
2
3use std::{
4    io::{Read, Write},
5    process::Stdio,
6};
7
8use anyhow::ensure;
9use itertools::Itertools;
10
11#[cfg(feature = "dsl")]
12mod dsl_hir;
13mod lir;
14#[cfg(feature = "manifest")]
15mod manifest;
16pub mod mir;
17
/// Generate the device driver source from DSL tokens.
///
/// The tokens are first lowered to MIR. When that fails, the returned string
/// contains the rendered compile error instead of driver code.
///
/// `driver_name` names the root block of the driver and should be written in
/// `PascalCase`.
#[cfg(feature = "dsl")]
pub fn transform_dsl(input: proc_macro2::TokenStream, driver_name: &str) -> String {
    _private_transform_dsl_mir(input).map_or_else(
        |error| error.into_compile_error().to_string(),
        |device| transform_mir(device, driver_name),
    )
}
31
/// Parse DSL tokens into the HIR and lower the result to a MIR device.
///
/// Any parse or lowering failure is returned as a [`syn::Error`].
#[doc(hidden)]
#[cfg(feature = "dsl")]
pub fn _private_transform_dsl_mir(
    input: proc_macro2::TokenStream,
) -> Result<mir::Device, syn::Error> {
    // Step 1: tokens -> HIR
    let parsed: dsl_hir::Device = syn::parse2(input)?;

    // Step 2: HIR -> MIR
    let device = dsl_hir::mir_transform::transform(parsed)?;

    Ok(device)
}
45
/// Generate the device driver source from a JSON manifest string.
///
/// When the manifest cannot be turned into MIR, the returned string contains
/// the rendered compile error instead of driver code.
///
/// `driver_name` names the root block of the driver and should be written in
/// `PascalCase`.
#[cfg(feature = "json")]
pub fn transform_json(source: &str, driver_name: &str) -> String {
    _private_transform_json_mir(source).map_or_else(anyhow_error_to_compile_error, |device| {
        transform_mir(device, driver_name)
    })
}
59
/// Parse a JSON manifest string and transform it into a MIR device.
#[doc(hidden)]
#[cfg(feature = "json")]
pub fn _private_transform_json_mir(source: &str) -> anyhow::Result<mir::Device> {
    use dd_manifest_tree::{JsonValue, parse_manifest};

    let manifest_tree = parse_manifest::<JsonValue>(source)?;
    let device = manifest::transform(manifest_tree)?;

    Ok(device)
}
68
/// Generate the device driver source from a YAML manifest string.
///
/// When the manifest cannot be turned into MIR, the returned string contains
/// the rendered compile error instead of driver code.
///
/// `driver_name` names the root block of the driver and should be written in
/// `PascalCase`.
#[cfg(feature = "yaml")]
pub fn transform_yaml(source: &str, driver_name: &str) -> String {
    _private_transform_yaml_mir(source).map_or_else(anyhow_error_to_compile_error, |device| {
        transform_mir(device, driver_name)
    })
}
82
/// Parse a YAML manifest string and transform it into a MIR device.
#[doc(hidden)]
#[cfg(feature = "yaml")]
pub fn _private_transform_yaml_mir(source: &str) -> anyhow::Result<mir::Device> {
    use dd_manifest_tree::{YamlValue, parse_manifest};

    let manifest_tree = parse_manifest::<YamlValue>(source)?;
    let device = manifest::transform(manifest_tree)?;

    Ok(device)
}
91
/// Generate the device driver source from a TOML manifest string.
///
/// When the manifest cannot be turned into MIR, the returned string contains
/// the rendered compile error instead of driver code.
///
/// `driver_name` names the root block of the driver and should be written in
/// `PascalCase`.
#[cfg(feature = "toml")]
pub fn transform_toml(source: &str, driver_name: &str) -> String {
    _private_transform_toml_mir(source).map_or_else(anyhow_error_to_compile_error, |device| {
        transform_mir(device, driver_name)
    })
}
105
/// Parse a TOML manifest string and transform it into a MIR device.
#[doc(hidden)]
#[cfg(feature = "toml")]
pub fn _private_transform_toml_mir(source: &str) -> anyhow::Result<mir::Device> {
    use dd_manifest_tree::{TomlValue, parse_manifest};

    let manifest_tree = parse_manifest::<TomlValue>(source)?;
    let device = manifest::transform(manifest_tree)?;

    Ok(device)
}
114
115fn transform_mir(mut mir: mir::Device, driver_name: &str) -> String {
116    // Run the MIR passes
117    match mir::passes::run_passes(&mut mir) {
118        Ok(()) => {}
119        Err(e) => return anyhow_error_to_compile_error(e),
120    }
121
122    // Transform into LIR
123    let mut lir = match mir::lir_transform::transform(mir, driver_name) {
124        Ok(lir) => lir,
125        Err(e) => return anyhow_error_to_compile_error(e),
126    };
127
128    // Run the LIR passes
129    match lir::passes::run_passes(&mut lir) {
130        Ok(()) => {}
131        Err(e) => return anyhow_error_to_compile_error(e),
132    };
133
134    // Transform into Rust source token output
135    let output = lir::code_transform::DeviceTemplateRust::new(&lir).to_string();
136
137    match format_code(&output) {
138        Ok(formatted_output) => formatted_output,
139        Err(e) => format!(
140            "{}\n\n{output}",
141            e.to_string().lines().map(|e| format!("// {e}")).join("\n")
142        ),
143    }
144}
145
146fn anyhow_error_to_compile_error(error: anyhow::Error) -> String {
147    syn::Error::new(proc_macro2::Span::call_site(), format!("{error:#}"))
148        .into_compile_error()
149        .to_string()
150}
151
152fn format_code(input: &str) -> Result<String, anyhow::Error> {
153    let mut cmd = std::process::Command::new("rustfmt");
154
155    cmd.args(["--edition", "2024"])
156        .args(["--config", "newline_style=native"])
157        .stdin(Stdio::piped())
158        .stdout(Stdio::piped())
159        .stderr(Stdio::piped());
160
161    let mut child = cmd.spawn()?;
162    let mut child_stdin = child.stdin.take().unwrap();
163    let mut child_stdout = child.stdout.take().unwrap();
164
165    // Write to stdin in a new thread, so that we can read from stdout on this
166    // thread. This keeps the child from blocking on writing to its stdout which
167    // might block us from writing to its stdin.
168    let output = std::thread::scope(|s| {
169        s.spawn(|| {
170            child_stdin.write_all(input.as_bytes()).unwrap();
171            child_stdin.flush().unwrap();
172            drop(child_stdin);
173        });
174        let handle: std::thread::ScopedJoinHandle<'_, Result<Vec<u8>, anyhow::Error>> =
175            s.spawn(|| {
176                let mut output = Vec::new();
177                child_stdout.read_to_end(&mut output)?;
178                Ok(output)
179            });
180
181        handle.join()
182    });
183
184    let status = child.wait()?;
185    ensure!(
186        status.success(),
187        "rustfmt exited unsuccesfully ({status}):\n{}",
188        child
189            .stderr
190            .map(|mut stderr| {
191                let mut err = String::new();
192                stderr.read_to_string(&mut err).unwrap();
193                err
194            })
195            .unwrap_or_default()
196    );
197
198    let output = match output {
199        Ok(output) => output,
200        Err(e) => std::panic::resume_unwind(e),
201    };
202
203    Ok(String::from_utf8(output?)?)
204}