use ndarray::{Array, IxDyn};
use scirs2_neural::layers::Layer;
use scirs2_neural::models::{BertConfig, BertModel};
#[allow(dead_code)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("BERT Model Example");
println!("Creating a small BERT model...");
let config = BertConfig::custom(
10000, 128, 2, 2, );
let model = BertModel::<f32>::new(config)?;
let input = Array::from_shape_fn(
IxDyn(&[2, 16]),
|_| rand::random::<f32>() * 100.0, println!("Input shape: {:?}", input.shape());
let sequence_output = model.forward(&input)?;
println!("Sequence output shape: {:?}", sequence_output.shape());
let pooled_output = model.get_pooled_output(&input)?;
println!("Pooled output shape: {:?}", pooled_output.shape());
println!("\nCreating a BERT-Base model...");
let bert_base = BertModel::<f32>::bert_base_uncased()?;
let base_input = Array::from_shape_fn(
IxDyn(&[1, 64]),
|_| rand::random::<f32>() * 1000.0, println!("BERT-Base input shape: {:?}", base_input.shape());
let base_pooled_output = bert_base.get_pooled_output(&base_input)?;
println!(
"BERT-Base pooled output shape: {:?}",
base_pooled_output.shape()
"BERT-Base hidden dimension: {}",
base_pooled_output.shape()[1]
println!("\nBERT example completed successfully!");
Ok(())
}