1use std::collections::HashMap;
7
8use anyhow::Result;
9use rlx_ir::hir::{HirModule, HirNodeId};
10use rlx_ir::{DType, GraphModule, Shape};
11
12use crate::profile::CompileProfile;
13use crate::value::FlowValue;
14use crate::weight::WeightSource;
15
16#[derive(Debug, Clone, Copy)]
18pub struct GdnInputSlots {
19 pub q: HirNodeId,
20 pub k: HirNodeId,
21 pub v: HirNodeId,
22 pub g: HirNodeId,
23 pub beta: HirNodeId,
24}
25
26#[derive(Debug, Default)]
28pub struct FlowState {
29 pub rope_cos: Option<HirNodeId>,
30 pub rope_sin: Option<HirNodeId>,
31 pub zero_beta: Option<HirNodeId>,
32 pub embed_weight: Option<HirNodeId>,
33 pub hidden_shape: Option<Shape>,
34 pub decode: Option<DecodeBindings>,
35 pub residual_skip: Option<HirNodeId>,
36 pub residual_shape: Option<Shape>,
37 pub streams: HashMap<String, FlowValue>,
39 pub inputs: HashMap<String, (HirNodeId, Shape)>,
41 pub named: HashMap<String, HirNodeId>,
43 pub gdn: Option<GdnInputSlots>,
45 pub loaded_params: HashMap<String, HirNodeId>,
48}
49
50#[derive(Debug, Clone)]
52pub struct DecodeBindings {
53 pub cos: HirNodeId,
54 pub sin: HirNodeId,
55 pub mask: Option<HirNodeId>,
56 pub past_k: Vec<HirNodeId>,
57 pub past_v: Vec<HirNodeId>,
58}
59
60pub struct FlowCtx<'a> {
62 pub(crate) module: GraphModule,
63 pub(crate) params: &'a mut HashMap<String, Vec<f32>>,
64 pub(crate) weights: &'a mut dyn WeightSource,
65 pub(crate) profile: &'a CompileProfile,
66 pub(crate) state: &'a mut FlowState,
67}
68
69impl FlowCtx<'_> {
70 pub fn hir(&mut self) -> &mut HirModule {
71 self.module
72 .as_hir_mut()
73 .expect("flow context requires HIR stage")
74 }
75
76 pub fn node_shape(&self, id: HirNodeId) -> Result<Shape> {
77 Ok(self
78 .module
79 .as_hir()
80 .ok_or_else(|| anyhow::anyhow!("flow context requires HIR stage"))?
81 .node(id)
82 .shape
83 .clone())
84 }
85
86 pub fn load_param(&mut self, key: &str, transpose: bool) -> Result<HirNodeId> {
87 let cache_key = param_cache_key(key, transpose);
88 if let Some(&id) = self.state.loaded_params.get(&cache_key) {
89 return Ok(id);
90 }
91 let (data, shape) = self.weights.take(key, transpose)?;
92 let ir_shape = Shape::new(&shape, DType::F32);
93 let id = self.hir().param(key, ir_shape);
94 self.params.insert(key.to_string(), data);
95 self.state.loaded_params.insert(cache_key, id);
96 Ok(id)
97 }
98
99 pub fn synth_param(&mut self, name: &str, data: Vec<f32>, shape: Shape) -> HirNodeId {
100 let id = self.hir().param(name, shape);
101 self.params.insert(name.to_string(), data);
102 id
103 }
104
105 pub fn synth_zeros(&mut self, name: &str, len: usize) -> HirNodeId {
106 self.synth_param(name, vec![0f32; len], Shape::new(&[len], DType::F32))
107 }
108
109 pub fn input(&mut self, name: &str, shape: Shape) -> HirNodeId {
110 self.hir().input(name, shape)
111 }
112
113 pub fn wrap(&self, id: HirNodeId, shape: Shape) -> FlowValue {
114 FlowValue::new(id, shape)
115 }
116}
117
/// Builds the cache key used for `FlowState::loaded_params`.
///
/// An untransposed load uses the weight name unchanged. A transposed load
/// appends a NUL byte followed by `t`, so the two variants get distinct keys.
fn param_cache_key(key: &str, transpose: bool) -> String {
    match transpose {
        false => key.to_owned(),
        true => {
            let mut tagged = String::with_capacity(key.len() + 2);
            tagged.push_str(key);
            tagged.push_str("\0t");
            tagged
        }
    }
}