1use bmi_rs::bmi::{Bmi, BmiResult, Location, RefValues, ValueType, Values, register_model};
2use burn::nn::LstmState;
3use burn::prelude::*;
4use burn::record::{FullPrecisionSettings, Recorder};
5use burn_import::pytorch::PyTorchFileRecorder;
6use glob::glob;
7use serde::{Deserialize, Serialize};
8use std::collections::HashMap;
9use std::fs;
10
11use std::path::Path;
12
13mod nextgen_lstm;
14mod python;
15use nextgen_lstm::{NextgenLstm, vec_to_tensor};
16use python::convert_model;
17
18#[derive(Debug, Serialize, Deserialize)]
19struct ModelMetadata {
20 input_size: usize,
21 hidden_size: usize,
22 output_size: usize,
23 input_names: Vec<String>,
24 output_names: Vec<String>,
25}
26
27#[derive(Debug, Serialize, Deserialize)]
28struct TrainingScalars {
29 input_mean: Vec<f32>,
30 input_std: Vec<f32>,
31 output_mean: f32,
32 output_std: f32,
33}
34
35struct ModelInstance<B: Backend> {
37 model: NextgenLstm<B>,
38 metadata: ModelMetadata,
39 scalars: TrainingScalars,
40 lstm_state: Option<LstmState<B, 2>>,
41}
42
/// Matches a variable name against string patterns.
///
/// Yields `$result` for the first matching arm. Otherwise it yields
/// `Err(Box::new(std::io::Error))` with kind `NotFound` and the message
/// `"Variable <name> not found"`, so the surrounding function's error type
/// must accept a boxed `std::io::Error` (e.g. `Box<dyn std::error::Error>`).
///
/// `$name` is evaluated exactly once, even on the not-found path.
macro_rules! match_var {
    ($name:expr, $($pattern:pat => $result:expr),+ $(,)?) => {{
        // Bind once: the fallback arm needs the name again for its message,
        // and re-expanding `$name` there would evaluate it a second time.
        let var_name = $name;
        match var_name {
            $($pattern => $result,)+
            _ => Err(Box::new(std::io::Error::new(
                std::io::ErrorKind::NotFound,
                format!("Variable {} not found", var_name),
            ))),
        }
    }};
}
55
56pub struct LstmBmi<B: Backend> {
57 models: Vec<ModelInstance<B>>,
59 device: B::Device,
60
61 config_path: String,
63 area_sqkm: f32,
64 output_scale_factor_cms: f32,
65
66 variables: HashMap<String, Vec<f64>>,
68
69 input_var_names: Vec<&'static str>,
71 output_var_names: Vec<&'static str>,
72
73 current_time: f64,
75 start_time: f64,
76 end_time: f64,
77 time_step: f64,
78}
79
80impl<B: Backend> LstmBmi<B> {
81 pub fn new(device: B::Device) -> Self {
82 let input_vars = vec![
84 "atmosphere_water__liquid_equivalent_precipitation_rate",
85 "land_surface_air__temperature",
86 "land_surface_radiation~incoming~longwave__energy_flux",
87 "land_surface_air__pressure",
88 "atmosphere_air_water~vapor__relative_saturation",
89 "land_surface_radiation~incoming~shortwave__energy_flux",
90 "land_surface_wind__x_component_of_velocity",
91 "land_surface_wind__y_component_of_velocity",
92 ];
95
96 let output_vars = vec![
97 "land_surface_water__runoff_volume_flux",
98 "land_surface_water__runoff_depth",
99 ];
100
101 let mut variables = HashMap::new();
103 for var in &input_vars {
104 variables.insert(var.to_string(), vec![0.0]);
105 }
106 variables.insert("basin__mean_of_elevation".to_string(), vec![0.0]);
108 variables.insert("basin__mean_of_slope".to_string(), vec![0.0]);
109
110 for var in &output_vars {
111 variables.insert(var.to_string(), vec![0.0]);
112 }
113
114 LstmBmi {
115 models: Vec::new(),
116 device,
117 config_path: String::new(),
118 area_sqkm: 0.0,
119 output_scale_factor_cms: 0.0,
120 variables,
121 input_var_names: input_vars,
122 output_var_names: output_vars,
123 current_time: 0.0,
124 start_time: 0.0,
125 end_time: 36000.0, time_step: 3600.0, }
128 }
129
130 fn internal_to_external_name(&self, internal: &str) -> String {
131 let mapping = [
132 (
133 "DLWRF_surface",
134 "land_surface_radiation~incoming~longwave__energy_flux",
135 ),
136 ("PRES_surface", "land_surface_air__pressure"),
137 (
138 "SPFH_2maboveground",
139 "atmosphere_air_water~vapor__relative_saturation",
140 ),
141 (
142 "APCP_surface",
143 "atmosphere_water__liquid_equivalent_precipitation_rate",
144 ),
145 (
146 "DSWRF_surface",
147 "land_surface_radiation~incoming~shortwave__energy_flux",
148 ),
149 ("TMP_2maboveground", "land_surface_air__temperature"),
150 (
151 "UGRD_10maboveground",
152 "land_surface_wind__x_component_of_velocity",
153 ),
154 (
155 "VGRD_10maboveground",
156 "land_surface_wind__y_component_of_velocity",
157 ),
158 ("elev_mean", "basin__mean_of_elevation"),
159 ("slope_mean", "basin__mean_of_slope"),
160 ];
161
162 mapping
163 .iter()
164 .find(|(k, _)| *k == internal)
165 .map(|(_, v)| v.to_string())
166 .unwrap_or_else(|| internal.to_string())
167 }
168
169 fn load_single_model(&self, training_config_path: &Path) -> BmiResult<ModelInstance<B>> {
170 let training_config = fs::read_to_string(training_config_path)?;
171 let training_config: serde_yaml::Value = serde_yaml::from_str(&training_config)?;
172
173 let model_dir = training_config["run_dir"]
175 .as_str()
176 .ok_or("Missing run_dir")?
177 .replace(
178 "..",
179 training_config_path
180 .parent()
181 .unwrap()
182 .parent()
183 .unwrap()
184 .parent()
185 .unwrap()
186 .to_str()
187 .unwrap(),
188 );
189
190 let model_path = glob(&format!("{}/model_*.pt", model_dir))?
191 .next()
192 .ok_or("No model file found")??;
193
194 let model_folder = model_path.parent().unwrap();
196 let burn_dir = model_folder.join("burn");
197 let converted_path = burn_dir.join(model_path.file_name().unwrap());
198 let lock_file_path = burn_dir.join(".conversion.lock");
199
200 if !burn_dir.exists() {
202 fs::create_dir_all(&burn_dir)?;
203 }
204
205 let needs_conversion = !converted_path.exists()
207 || !converted_path.with_extension("json").exists()
208 || !burn_dir.join("train_data_scaler.json").exists()
209 || !burn_dir.join("weights.json").exists();
210
211 if needs_conversion {
212 let mut lock_acquired = false;
214 let process_id = std::process::id();
215
216 loop {
217 match fs::OpenOptions::new()
219 .write(true)
220 .create_new(true)
221 .open(&lock_file_path)
222 {
223 Ok(mut file) => {
224 use std::io::Write;
226 writeln!(file, "Locked by process {}", process_id)?;
227 lock_acquired = true;
228 break;
230 }
231 Err(_) => {
232 std::thread::sleep(std::time::Duration::from_millis(1000));
238
239 if converted_path.exists()
241 && converted_path.with_extension("json").exists()
242 && burn_dir.join("train_data_scaler.json").exists()
243 && burn_dir.join("weights.json").exists()
244 {
245 break;
247 }
248 }
249 }
250 }
251
252 if lock_acquired {
254 println!(
255 "Process {} converting PyTorch weights to Burn format for model: {}",
256 process_id,
257 model_path.display()
258 );
259
260 match convert_model(&model_path, &training_config_path) {
262 Ok(_) => {
263 println!("Process {} completed model conversion", process_id);
264 }
265 Err(e) => {
266 let _ = fs::remove_file(&lock_file_path);
268 return Err(Box::new(std::io::Error::new(
269 std::io::ErrorKind::Other,
270 format!("Model conversion failed: {}", e),
271 )));
272 }
273 }
274
275 fs::remove_file(&lock_file_path)?;
277 println!("Process {} released conversion lock", process_id);
278 }
279 } else {
280 }
282
283 let metadata_str = fs::read_to_string(converted_path.with_extension("json"))?;
285 let metadata: ModelMetadata = serde_json::from_str(&metadata_str)?;
286
287 let scalars_str = fs::read_to_string(burn_dir.join("train_data_scaler.json"))?;
289 let scalars: TrainingScalars = serde_json::from_str(&scalars_str)?;
290
291 let record = PyTorchFileRecorder::<FullPrecisionSettings>::default()
293 .load(converted_path.into(), &self.device)?;
294
295 let mut model = NextgenLstm::init(
296 &self.device,
297 metadata.input_size,
298 metadata.hidden_size,
299 metadata.output_size,
300 );
301 model = model.load_record(record);
302 model.load_json_weights(
303 &self.device,
304 burn_dir.join("weights.json").to_str().unwrap(),
305 );
306
307 Ok(ModelInstance {
308 model,
309 metadata,
310 scalars,
311 lstm_state: None,
312 })
313 }
314
315 fn run_single_model(&mut self, model_idx: usize, inputs: &[f32]) -> BmiResult<f32> {
316 let model_instance = &mut self.models[model_idx];
317
318 let scaled_inputs: Vec<f32> = inputs
320 .iter()
321 .zip(&model_instance.scalars.input_mean)
322 .zip(&model_instance.scalars.input_std)
323 .map(
324 |((val, mean), std)| {
325 if *std != 0.0 { (val - mean) / std } else { 0.0 }
326 },
327 )
328 .collect();
329
330 let input_tensor_data = vec_to_tensor(
332 &scaled_inputs,
333 vec![1, 1, model_instance.metadata.input_size],
334 );
335 let input_tensor = Tensor::from_data(input_tensor_data, &self.device);
336
337 let (output, new_state) = model_instance
339 .model
340 .forward(input_tensor, model_instance.lstm_state.take());
341 model_instance.lstm_state = Some(new_state);
342
343 let output_vec: Vec<f32> = output.into_data().to_vec().unwrap();
345 let output_value = output_vec[0];
346
347 let surface_runoff_mm = (output_value * model_instance.scalars.output_std
349 + model_instance.scalars.output_mean)
350 .max(0.0);
351
352 Ok(surface_runoff_mm)
353 }
354
355 fn run_ensemble(&mut self) -> BmiResult<()> {
356 if self.models.is_empty() {
357 return Err("No models in ensemble")?;
358 }
359 let mut ensemble_outputs = Vec::new();
361 for i in 0..self.models.len() {
362 let input_names = self.models[i].metadata.input_names.clone();
363 let mut inputs = Vec::new();
365 for name in &input_names {
366 let bmi_name = self.internal_to_external_name(name);
367 let value = self
368 .variables
369 .get(&bmi_name)
370 .and_then(|v| v.first())
371 .copied()
372 .unwrap_or(0.0) as f32;
373 inputs.push(value);
374 }
375 let output = self.run_single_model(i, &inputs)?;
376 ensemble_outputs.push(output);
377 }
378
379 let mean_surface_runoff_mm = if !ensemble_outputs.is_empty() {
381 ensemble_outputs.iter().sum::<f32>() / ensemble_outputs.len() as f32
382 } else {
383 0.0
384 };
385
386 let surface_runoff_m = mean_surface_runoff_mm / 1000.0;
388 let surface_runoff_volume_m3_s = mean_surface_runoff_mm * self.output_scale_factor_cms;
389
390 self.variables.insert(
392 "land_surface_water__runoff_depth".to_string(),
393 vec![surface_runoff_m as f64],
394 );
395 self.variables.insert(
396 "land_surface_water__runoff_volume_flux".to_string(),
397 vec![surface_runoff_volume_m3_s as f64],
398 );
399 Ok(())
400 }
401}
402
403impl<B: Backend> Bmi for LstmBmi<B> {
404 fn initialize(&mut self, config_file: &str) -> BmiResult<()> {
405 self.config_path = config_file.to_string();
407
408 let config_path = Path::new(config_file);
410 let config_str = fs::read_to_string(config_path)?;
411 let config: serde_yaml::Value = serde_yaml::from_str(&config_str)?;
412
413 let training_configs = config["train_cfg_file"]
415 .as_sequence()
416 .ok_or("train_cfg_file should be an array")?;
417
418 for (idx, config_value) in training_configs.iter().enumerate() {
422 let training_config_path = Path::new(
423 config_value
424 .as_str()
425 .ok_or(format!("train_cfg_file[{}] not a string", idx))?,
426 );
427
428 let model_instance = self.load_single_model(training_config_path)?;
436 self.models.push(model_instance);
437 }
438
439 self.area_sqkm = config
441 .get("area_sqkm")
442 .ok_or("Missing area_sqkm")?
443 .as_f64()
444 .ok_or("area_sqkm not a number")? as f32;
445
446 self.output_scale_factor_cms =
447 (1.0 / 1000.0) * (self.area_sqkm * 1000.0 * 1000.0) * (1.0 / 3600.0);
448
449 let elevation = config
451 .get("elev_mean")
452 .and_then(|v| v.as_f64())
453 .unwrap_or(0.0);
454 let slope = config
455 .get("slope_mean")
456 .and_then(|v| v.as_f64())
457 .unwrap_or(0.0);
458
459 self.variables
460 .insert("basin__mean_of_elevation".to_string(), vec![elevation]);
461 self.variables
462 .insert("basin__mean_of_slope".to_string(), vec![slope]);
463
464 self.current_time = self.start_time;
466
467 Ok(())
472 }
473
474 fn update(&mut self) -> BmiResult<()> {
475 self.run_ensemble()?;
476 self.current_time += self.time_step;
477 Ok(())
478 }
479
480 fn update_until(&mut self, then: f64) -> BmiResult<()> {
481 if then < self.current_time {
482 return Err(Box::new(std::io::Error::new(
483 std::io::ErrorKind::InvalidInput,
484 format!(
485 "Target time {} is before current time {}",
486 then, self.current_time
487 ),
488 )));
489 }
490
491 while self.current_time < then {
492 self.update()?;
493 if self.current_time > then {
494 self.current_time = then;
495 }
496 }
497 Ok(())
498 }
499
500 fn finalize(&mut self) -> BmiResult<()> {
501 self.models.clear();
502 Ok(())
503 }
504
505 fn get_component_name(&self) -> &str {
506 "NextGen LSTM BMI Ensemble"
507 }
508
509 fn get_input_item_count(&self) -> u32 {
510 self.input_var_names.len() as u32
511 }
512
513 fn get_output_item_count(&self) -> u32 {
514 self.output_var_names.len() as u32
515 }
516
517 fn get_input_var_names(&self) -> &[&str] {
518 &self.input_var_names
519 }
520
521 fn get_output_var_names(&self) -> &[&str] {
522 &self.output_var_names
523 }
524
525 fn get_var_grid(&self, name: &str) -> BmiResult<i32> {
526 if self.variables.contains_key(name) {
527 Ok(0) } else {
529 Err(Box::new(std::io::Error::new(
530 std::io::ErrorKind::NotFound,
531 format!("Variable {} not found", name),
532 )))
533 }
534 }
535
536 fn get_var_type(&self, name: &str) -> BmiResult<ValueType> {
537 if self.variables.contains_key(name) {
538 Ok(ValueType::F64)
539 } else {
540 Err(Box::new(std::io::Error::new(
541 std::io::ErrorKind::NotFound,
542 format!("Variable {} not found", name),
543 )))
544 }
545 }
546
547 fn get_var_units(&self, name: &str) -> BmiResult<&str> {
548 match_var!(name,
549 "atmosphere_water__liquid_equivalent_precipitation_rate" => Ok("mm h-1"),
550 "land_surface_air__temperature" => Ok("degK"),
551 "land_surface_radiation~incoming~longwave__energy_flux" => Ok("W m-2"),
552 "land_surface_air__pressure" => Ok("Pa"),
553 "atmosphere_air_water~vapor__relative_saturation" => Ok("kg kg-1"),
554 "land_surface_radiation~incoming~shortwave__energy_flux" => Ok("W m-2"),
555 "land_surface_wind__x_component_of_velocity" => Ok("m s-1"),
556 "land_surface_wind__y_component_of_velocity" => Ok("m s-1"),
557 "basin__mean_of_elevation" => Ok("m"),
558 "basin__mean_of_slope" => Ok("m km-1"),
559 "land_surface_water__runoff_volume_flux" => Ok("m3 s-1"),
560 "land_surface_water__runoff_depth" => Ok("m")
561 )
562 }
563
564 fn get_var_nbytes(&self, name: &str) -> BmiResult<u32> {
565 let itemsize = self.get_var_itemsize(name)?;
566 let values = self.get_value_ptr(name)?;
567 Ok(values.len() as u32 * itemsize)
568 }
569
570 fn get_var_location(&self, name: &str) -> BmiResult<Location> {
571 if self.variables.contains_key(name) {
572 Ok(Location::Node)
573 } else {
574 Err(Box::new(std::io::Error::new(
575 std::io::ErrorKind::NotFound,
576 format!("Variable {} not found", name),
577 )))
578 }
579 }
580
581 fn get_current_time(&self) -> f64 {
582 self.current_time
583 }
584
585 fn get_start_time(&self) -> f64 {
586 self.start_time
587 }
588
589 fn get_end_time(&self) -> f64 {
590 self.end_time
591 }
592
593 fn get_time_units(&self) -> &str {
594 "seconds"
595 }
596
597 fn get_time_step(&self) -> f64 {
598 self.time_step
599 }
600
601 fn get_value_ptr(&self, name: &str) -> BmiResult<RefValues<'_>> {
602 Ok(self
603 .variables
604 .get(name)
605 .map(|v| RefValues::F64(v))
606 .ok_or_else(|| {
607 Box::new(std::io::Error::new(
608 std::io::ErrorKind::NotFound,
609 format!("Variable {} not found", name),
610 ))
611 })?)
612 }
613
614 fn get_value_at_indices(&self, name: &str, inds: &[u32]) -> BmiResult<Values> {
615 let values = self.variables.get(name).ok_or_else(|| {
616 std::io::Error::new(
617 std::io::ErrorKind::NotFound,
618 format!("Variable {} not found", name),
619 )
620 })?;
621
622 let mut result = Vec::with_capacity(inds.len());
623 for &idx in inds {
624 if (idx as usize) >= values.len() {
625 return Err(Box::new(std::io::Error::new(
626 std::io::ErrorKind::InvalidInput,
627 format!("Index {} out of bounds", idx),
628 )));
629 }
630 result.push(values[idx as usize]);
631 }
632 Ok(Values::F64(result))
633 }
634
635 fn set_value(&mut self, name: &str, src: RefValues) -> BmiResult<()> {
636 if let RefValues::F64(values) = src {
637 if let Some(var) = self.variables.get_mut(name) {
638 if values.len() != var.len() {
639 return Err(Box::new(std::io::Error::new(
640 std::io::ErrorKind::InvalidInput,
641 "Source array size mismatch",
642 )));
643 }
644 var.copy_from_slice(values);
645 Ok(())
646 } else {
647 Err(Box::new(std::io::Error::new(
648 std::io::ErrorKind::NotFound,
649 format!("Variable {} not found", name),
650 )))
651 }
652 } else {
653 Err(Box::new(std::io::Error::new(
654 std::io::ErrorKind::InvalidInput,
655 "Type mismatch: expected F64",
656 )))
657 }
658 }
659
660 fn set_value_at_indices(&mut self, name: &str, inds: &[u32], src: RefValues) -> BmiResult<()> {
661 if let RefValues::F64(values) = src {
662 if values.len() != inds.len() {
663 return Err(Box::new(std::io::Error::new(
664 std::io::ErrorKind::InvalidInput,
665 "Source array size doesn't match indices count",
666 )));
667 }
668
669 let var = self.variables.get_mut(name).ok_or_else(|| {
670 std::io::Error::new(
671 std::io::ErrorKind::NotFound,
672 format!("Variable {} not found", name),
673 )
674 })?;
675
676 for (i, &idx) in inds.iter().enumerate() {
677 if (idx as usize) >= var.len() {
678 return Err(Box::new(std::io::Error::new(
679 std::io::ErrorKind::InvalidInput,
680 format!("Index {} out of bounds", idx),
681 )));
682 }
683 var[idx as usize] = values[i];
684 }
685 Ok(())
686 } else {
687 Err(Box::new(std::io::Error::new(
688 std::io::ErrorKind::InvalidInput,
689 "Type mismatch: expected F64",
690 )))
691 }
692 }
693}
694
695#[unsafe(no_mangle)]
697pub extern "C" fn register_bmi_lstm(handle: *mut ffi::Bmi) -> *mut ffi::Bmi {
698 type Backend = burn::backend::Candle;
700 let device = Default::default();
701
702 let model = LstmBmi::<Backend>::new(device);
703 register_model(handle, model);
704 handle
705}