1use std::collections::HashMap;
22
23use anyhow::Result;
24use rlx_ir::hir::HirModule;
25use rlx_ir::{GraphModule, HirNodeId, Shape};
26
27use crate::context::{FlowCtx, FlowState};
28use crate::profile::CompileProfile;
29use crate::value::FlowValue;
30use crate::weight::WeightSource;
31
/// Mutable view over the parts of a [`FlowCtx`] that flow emission reads and writes.
///
/// Built with `Emit::from_ctx`, which reborrows each field from the context for `'a`.
pub struct Emit<'a> {
    /// Graph under construction; the HIR accessors panic unless it is at the HIR stage.
    pub module: &'a mut GraphModule,
    /// Parameter tensors keyed by the name given to the corresponding HIR param node.
    pub params: &'a mut HashMap<String, Vec<f32>>,
    /// Source that `load_param` takes weight data (and its shape) from.
    pub weights: &'a mut dyn WeightSource,
    /// Per-flow bookkeeping: loaded-parameter cache, flow inputs and named handles.
    pub state: &'a mut FlowState,
    /// Compile profile borrowed from the originating context (read-only).
    pub profile: &'a CompileProfile,
}
40
41impl<'a> Emit<'a> {
42 pub(crate) fn from_ctx(ctx: &'a mut FlowCtx<'_>) -> Self {
43 Self {
44 module: &mut ctx.module,
45 params: ctx.params,
46 weights: ctx.weights,
47 state: ctx.state,
48 profile: ctx.profile,
49 }
50 }
51
52 pub fn hir(&mut self) -> &mut HirModule {
53 self.module
54 .as_hir_mut()
55 .expect("flow context requires HIR stage")
56 }
57
58 pub fn load_param(&mut self, key: &str, transpose: bool) -> Result<HirNodeId> {
59 let cache_key = if transpose {
60 format!("{key}\0t")
61 } else {
62 key.to_string()
63 };
64 if let Some(&id) = self.state.loaded_params.get(&cache_key) {
65 return Ok(id);
66 }
67 let (data, shape) = self.weights.take(key, transpose)?;
68 let ir_shape = Shape::new(&shape, rlx_ir::DType::F32);
69 let id = self.hir().param(key, ir_shape);
70 self.params.insert(key.to_string(), data);
71 self.state.loaded_params.insert(cache_key, id);
72 Ok(id)
73 }
74
75 pub fn synth_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
76 let id = self.hir().param(name, shape);
77 self.params.insert(name.to_string(), data);
78 id
79 }
80
81 pub fn synth_zeros(&mut self, name: &str, len: usize) -> HirNodeId {
82 self.synth_param(
83 name,
84 vec![0f32; len],
85 Shape::new(&[len], rlx_ir::DType::F32),
86 )
87 }
88
89 pub fn hir_and_params(&mut self) -> (&mut HirModule, &mut HashMap<String, Vec<f32>>) {
90 (
91 self.module
92 .as_hir_mut()
93 .expect("flow context requires HIR stage"),
94 self.params,
95 )
96 }
97
98 pub fn wrap(&self, id: HirNodeId, shape: Shape) -> FlowValue {
99 FlowValue::new(id, shape)
100 }
101
102 pub fn flow_input(&self, name: &str) -> Result<FlowValue> {
104 let (id, shape) = self
105 .state
106 .inputs
107 .get(name)
108 .ok_or_else(|| anyhow::anyhow!("flow input missing `{name}`"))?;
109 Ok(FlowValue::new(*id, shape.clone()))
110 }
111
112 pub fn set_named(&mut self, key: impl Into<String>, id: HirNodeId) {
113 self.state.named.insert(key.into(), id);
114 }
115
116 pub fn named(&self, key: &str) -> Result<HirNodeId> {
117 self.state
118 .named
119 .get(key)
120 .copied()
121 .ok_or_else(|| anyhow::anyhow!("named flow handle missing `{key}`"))
122 }
123}
124
125pub use crate::context::DecodeBindings;