1use std::collections::HashMap;
10
11use anyhow::Result;
12use rlx_ir::hir::HirModule;
13use rlx_ir::{GraphModule, HirNodeId, Shape};
14
15use crate::context::{FlowCtx, FlowState};
16use crate::profile::CompileProfile;
17use crate::value::FlowValue;
18use crate::weight::WeightSource;
19
20pub struct Emit<'a> {
22 pub module: &'a mut GraphModule,
23 pub params: &'a mut HashMap<String, Vec<f32>>,
24 pub weights: &'a mut dyn WeightSource,
25 pub state: &'a mut FlowState,
26 pub profile: &'a CompileProfile,
27}
28
29impl<'a> Emit<'a> {
30 pub(crate) fn from_ctx(ctx: &'a mut FlowCtx<'_>) -> Self {
31 Self {
32 module: &mut ctx.module,
33 params: ctx.params,
34 weights: ctx.weights,
35 state: ctx.state,
36 profile: ctx.profile,
37 }
38 }
39
40 pub fn hir(&mut self) -> &mut HirModule {
41 self.module
42 .as_hir_mut()
43 .expect("flow context requires HIR stage")
44 }
45
46 pub fn load_param(&mut self, key: &str, transpose: bool) -> Result<HirNodeId> {
47 let cache_key = if transpose {
48 format!("{key}\0t")
49 } else {
50 key.to_string()
51 };
52 if let Some(&id) = self.state.loaded_params.get(&cache_key) {
53 return Ok(id);
54 }
55 let (data, shape) = self.weights.take(key, transpose)?;
56 let ir_shape = Shape::new(&shape, rlx_ir::DType::F32);
57 let id = self.hir().param(key, ir_shape);
58 self.params.insert(key.to_string(), data);
59 self.state.loaded_params.insert(cache_key, id);
60 Ok(id)
61 }
62
63 pub fn synth_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
64 let id = self.hir().param(name, shape);
65 self.params.insert(name.to_string(), data);
66 id
67 }
68
69 pub fn synth_zeros(&mut self, name: &str, len: usize) -> HirNodeId {
70 self.synth_param(
71 name,
72 vec![0f32; len],
73 Shape::new(&[len], rlx_ir::DType::F32),
74 )
75 }
76
77 pub fn hir_and_params(&mut self) -> (&mut HirModule, &mut HashMap<String, Vec<f32>>) {
78 (
79 self.module
80 .as_hir_mut()
81 .expect("flow context requires HIR stage"),
82 self.params,
83 )
84 }
85
86 pub fn wrap(&self, id: HirNodeId, shape: Shape) -> FlowValue {
87 FlowValue::new(id, shape)
88 }
89
90 pub fn flow_input(&self, name: &str) -> Result<FlowValue> {
92 let (id, shape) = self
93 .state
94 .inputs
95 .get(name)
96 .ok_or_else(|| anyhow::anyhow!("flow input missing `{name}`"))?;
97 Ok(FlowValue::new(*id, shape.clone()))
98 }
99
100 pub fn set_named(&mut self, key: impl Into<String>, id: HirNodeId) {
101 self.state.named.insert(key.into(), id);
102 }
103
104 pub fn named(&self, key: &str) -> Result<HirNodeId> {
105 self.state
106 .named
107 .get(key)
108 .copied()
109 .ok_or_else(|| anyhow::anyhow!("named flow handle missing `{key}`"))
110 }
111}
112
113pub use crate::context::DecodeBindings;