use ndarray::{Array, IxDyn};
use scirs2_neural::layers::{
Embedding, EmbeddingConfig, Layer, PatchEmbedding, PositionalEmbedding,
};
#[allow(dead_code)]
fn main() -> Result<(), Box<dyn std::error::Error>> {
println!("Running embedding examples...");
println!("\n--- Basic Embedding Example ---");
let config = EmbeddingConfig {
num_embeddings: 10,
embedding_dim: 5,
padding_idx: Some(0),
max_norm: None,
norm_type: 2.0,
scale_grad_by_freq: false,
sparse: false,
};
let embedding = Embedding::<f32>::new(config)?;
let indices = Array::from_shape_vec(IxDyn(&[2, 3]), vec![1, 2, 0, 3, 0, 4])?;
let output = embedding.forward(&indices.mapv(|x| x as f32))?;
println!("Input indices shape: {:?}", indices.shape());
println!("Output embeddings shape: {:?}", output.shape());
println!(
"First embedding vector: {:?}",
output.slice(ndarray::s![0, 0, ..]).to_owned()
);
println!("\n--- Positional Embedding Example ---");
let pos_embedding = PositionalEmbedding::<f32>::new(10, 8, false)?;
let token_embeddings = Array::from_shape_fn(IxDyn(&[2, 5, 8]), |_| 1.0f32);
let output = pos_embedding.forward(&token_embeddings)?;
println!("Token embeddings shape: {:?}", token_embeddings.shape());
"Output with positional encoding shape: {:?}",
output.shape()
"First token before positional encoding: {:?}",
token_embeddings.slice(ndarray::s![0, 0, ..]).to_owned()
"First token after positional encoding: {:?}",
println!("\n--- Patch Embedding Example ---");
let patch_embedding = PatchEmbedding::<f32>::new((32, 32), (8, 8), 3, 96, true)?;
let image_input = Array::from_shape_fn(IxDyn(&[1, 3, 32, 32]), |_| rand::random::<f32>());
let output = patch_embedding.forward(&image_input)?;
println!("Input image shape: {:?}", image_input.shape());
println!("Patch embeddings shape: {:?}", output.shape());
println!("Number of patches: {}", patch_embedding.num_patches());
println!("Embedding dimension: {}", patch_embedding.embedding_dim);
"First patch embedding (first 5 values): {:?}",
output.slice(ndarray::s![0, 0, ..5]).to_owned()
println!("\nAll embedding examples completed successfully!");
Ok(())
}