use scirs2_core::random::Random;
use scirs2_core::RngExt;
use torsh_core::{Result as TorshResult, TorshError};
use torsh_tensor::Tensor;
/// Draws `num_samples` category indices from the categorical
/// distribution(s) described by `input`.
///
/// `input` must be 1D (a single distribution over `num_cols` categories) or
/// 2D (one distribution per row). Rows are normalized internally, so entries
/// only need to be non-negative weights, not probabilities summing to 1.
///
/// * `replacement` — when `false`, each category can be drawn at most once
///   per row, so `num_samples` must not exceed the number of categories.
/// * `generator` — optional RNG seed; a fixed default seed (42) is used when
///   absent so results are reproducible. NOTE(review): a PyTorch-style
///   default would be nondeterministic — confirm this is intentional.
///
/// Returns a tensor of sampled indices (stored as `f32`), shaped
/// `[num_samples]` for 1D input and `[num_rows, num_samples]` for 2D input.
///
/// # Errors
/// * dimension error if `input` has more than 2 dimensions;
/// * `InvalidArgument` if sampling without replacement with more samples
///   than categories, or if a row's weights sum to zero (or become zero
///   during no-replacement sampling).
pub fn multinomial(
    input: &Tensor,
    num_samples: usize,
    replacement: bool,
    generator: Option<u64>,
) -> TorshResult<Tensor> {
    if input.ndim() > 2 {
        return Err(TorshError::dimension_error(
            "input must be 1D or 2D",
            "multinomial",
        ));
    }
    let is_1d = input.ndim() == 1;
    // Promote 1D input to a single-row 2D view so one code path handles both.
    let input_2d = if is_1d {
        input.view(&[1, -1])?
    } else {
        input.clone()
    };
    let (num_rows, num_cols) = (input_2d.shape().dims()[0], input_2d.shape().dims()[1]);
    if !replacement && num_samples > num_cols {
        return Err(TorshError::InvalidArgument(format!(
            "Cannot sample {} samples without replacement from {} categories",
            num_samples, num_cols
        )));
    }
    let mut rng = match generator {
        Some(seed) => Random::seed(seed),
        None => Random::seed(42),
    };
    // Fetch the backing data ONCE. The original called `input_2d.data()?`
    // inside the row loop, repeating loop-invariant work for every row.
    let data = input_2d.data()?;
    let mut output_data = Vec::with_capacity(num_rows * num_samples);
    for row in 0..num_rows {
        let row_start = row * num_cols;
        // Borrow the row instead of copying it to a fresh Vec per row.
        let row_probs = &data[row_start..row_start + num_cols];
        let total: f32 = row_probs.iter().sum();
        if total <= 0.0 {
            return Err(TorshError::InvalidArgument(
                "multinomial: all probabilities are zero".to_string(),
            ));
        }
        if replacement {
            // Build the normalized cumulative distribution in one pass
            // (the original allocated a normalized Vec, then a cumulative
            // Vec; dividing while accumulating is equivalent).
            let mut cumulative: Vec<f32> = Vec::with_capacity(num_cols);
            let mut acc = 0.0f32;
            for &p in row_probs {
                acc += p / total;
                cumulative.push(acc);
            }
            for _ in 0..num_samples {
                let r: f32 = rng.gen_range(0.0..1.0);
                // Fallback to the last bucket covers float rounding where
                // the final cumulative value lands slightly below 1.0.
                let idx = cumulative
                    .iter()
                    .position(|&cum_prob| r <= cum_prob)
                    .unwrap_or(cumulative.len() - 1);
                output_data.push(idx as f32);
            }
        } else {
            // Without replacement: renormalize the surviving weights before
            // each draw, then remove the chosen category.
            let mut indices: Vec<usize> = (0..num_cols).collect();
            let mut remaining: Vec<f32> = row_probs.iter().map(|&p| p / total).collect();
            for _ in 0..num_samples {
                let remaining_total: f32 = remaining.iter().sum();
                if remaining_total <= 0.0 {
                    return Err(TorshError::InvalidArgument(
                        "multinomial: all remaining probabilities are zero".to_string(),
                    ));
                }
                let mut cumulative: Vec<f32> = Vec::with_capacity(remaining.len());
                let mut acc = 0.0f32;
                for &p in &remaining {
                    acc += p / remaining_total;
                    cumulative.push(acc);
                }
                let rand_val: f32 = rng.gen_range(0.0..1.0);
                let idx = cumulative
                    .iter()
                    .position(|&x| x >= rand_val)
                    .unwrap_or(cumulative.len() - 1);
                output_data.push(indices[idx] as f32);
                indices.remove(idx);
                remaining.remove(idx);
            }
        }
    }
    let output_shape = if is_1d {
        vec![num_samples]
    } else {
        vec![num_rows, num_samples]
    };
    Tensor::from_vec(output_data, &output_shape)
}
/// Creates a tensor of the given `shape` filled with independent Bernoulli
/// draws: each element is `1.0` with probability `p` and `0.0` otherwise.
///
/// * `generator` — optional RNG seed; a fixed default seed (42) is used when
///   absent so results are reproducible.
///
/// # Errors
/// Returns `InvalidArgument` if `p` lies outside `[0, 1]`.
pub fn bernoulli_(shape: &[usize], p: f32, generator: Option<u64>) -> TorshResult<Tensor> {
    if !(0.0..=1.0).contains(&p) {
        return Err(TorshError::InvalidArgument(
            "bernoulli_: p must be between 0 and 1".to_string(),
        ));
    }
    let mut rng = match generator {
        Some(seed) => Random::seed(seed),
        None => Random::seed(42),
    };
    // Total element count is the product of all dimension sizes.
    let numel: usize = shape.iter().product();
    let values: Vec<f32> = (0..numel)
        .map(|_| if rng.random::<f32>() < p { 1.0 } else { 0.0 })
        .collect();
    Tensor::from_vec(values, shape)
}
pub fn bernoulli(input: &Tensor, generator: Option<u64>) -> TorshResult<Tensor> {
let mut rng = if let Some(seed) = generator {
Random::seed(seed)
} else {
Random::seed(42) };
let data = input.data()?;
let values: Vec<f32> = data
.iter()
.map(|&p| {
if !(0.0..=1.0).contains(&p) {
panic!("bernoulli: all values in input must be between 0 and 1");
}
if rng.random::<f32>() < p {
1.0
} else {
0.0
}
})
.collect();
Tensor::from_vec(values, &input.shape().dims().to_vec())
}