use crate::error::{RusTorchError, RusTorchResult};
use crate::tensor::Tensor;
use num_traits::Float;
/// Stateless namespace type: the helpers in its `impl` block are associated
/// functions that never take `self`, so no instance is needed to call them.
pub struct CommonOps;
impl CommonOps {
/// Fallback all-gather that performs no communication.
///
/// Produces a vector holding `world_size` clones of `tensor`, i.e. the local
/// tensor stands in for the contribution of every rank.
///
/// # Errors
///
/// This implementation never fails; it always returns `Ok`.
pub fn default_all_gather<T>(
    tensor: &Tensor<T>,
    world_size: usize,
) -> RusTorchResult<Vec<Tensor<T>>>
where
    T: Float + Send + Sync + 'static,
{
    let replicas: Vec<Tensor<T>> = std::iter::repeat_with(|| tensor.clone())
        .take(world_size)
        .collect();
    Ok(replicas)
}
/// Fallback gather that performs no communication.
///
/// Returns `world_size` clones of `tensor`. The `_root` argument is accepted
/// for signature compatibility but is not consulted.
///
/// # Errors
///
/// This implementation never fails; it always returns `Ok`.
pub fn default_gather<T>(
    tensor: &Tensor<T>,
    world_size: usize,
    _root: usize,
) -> RusTorchResult<Vec<Tensor<T>>>
where
    T: Float + Send + Sync + 'static,
{
    let copies = (0..world_size).map(|_| tensor.clone()).collect();
    Ok(copies)
}
/// Fallback broadcast that does nothing.
///
/// Neither `_tensor` nor `_root` is read or modified; the tensor is left
/// exactly as the caller passed it.
///
/// # Errors
///
/// This implementation never fails; it always returns `Ok(())`.
pub fn default_broadcast<T>(_tensor: &mut Tensor<T>, _root: usize) -> RusTorchResult<()>
where
    T: Float + Send + Sync + 'static,
{
    Ok(())
}
/// Fallback all-reduce that does nothing.
///
/// Neither `_tensor` nor `_op` is read or modified; the tensor keeps its
/// local values whatever reduction operation was requested.
///
/// # Errors
///
/// This implementation never fails; it always returns `Ok(())`.
pub fn default_all_reduce<T>(_tensor: &mut Tensor<T>, _op: super::ReduceOp) -> RusTorchResult<()>
where
    T: Float + Send + Sync + 'static,
{
    Ok(())
}
/// Checks that `tensor` has a shape these helpers accept.
///
/// The shape is rejected when it has no dimensions, when any dimension is
/// zero, or when the total element count exceeds 1_000_000_000.
///
/// The element count is computed with overflow-checked multiplication. A
/// shape whose product does not fit in a `usize` is reported as too large
/// instead of panicking (debug builds) or wrapping to a small value that
/// would pass the limit (release builds).
///
/// # Errors
///
/// * [`RusTorchError::CommunicationError`] if the shape is empty, if the
///   element count exceeds the limit, or if it overflows `usize`.
/// * [`RusTorchError::ShapeMismatch`] if any dimension is zero; `expected`
///   is the fixed placeholder `[1]` and `actual` is the offending shape.
pub fn validate_tensor<T: Float + 'static>(tensor: &Tensor<T>) -> RusTorchResult<()> {
    const MAX_ELEMENTS: usize = 1_000_000_000;

    let shape = tensor.shape();
    if shape.is_empty() {
        return Err(RusTorchError::CommunicationError(
            "Empty tensor shape".to_string(),
        ));
    }
    if shape.contains(&0) {
        return Err(RusTorchError::ShapeMismatch {
            expected: vec![1],
            actual: shape.to_vec(),
        });
    }

    // Zero dimensions were rejected above, so a `None` here can only mean the
    // product genuinely exceeds `usize::MAX`.
    let total_elements = shape
        .iter()
        .try_fold(1usize, |count, &dim| count.checked_mul(dim))
        .ok_or_else(|| {
            RusTorchError::CommunicationError(format!(
                "Tensor too large: element count overflows usize, exceeds maximum {}",
                MAX_ELEMENTS
            ))
        })?;

    if total_elements > MAX_ELEMENTS {
        return Err(RusTorchError::CommunicationError(format!(
            "Tensor too large: {} elements exceeds maximum {}",
            total_elements, MAX_ELEMENTS
        )));
    }
    Ok(())
}
/// Checks that `rank` is a valid index into a group of `world_size` ranks.
///
/// A rank is valid when it is strictly less than `world_size`; consequently
/// every rank is invalid when `world_size` is zero.
///
/// # Errors
///
/// Returns [`RusTorchError::InvalidRank`] carrying a message with both values
/// when `rank >= world_size`.
pub fn validate_rank(rank: usize, world_size: usize) -> RusTorchResult<()> {
    if rank < world_size {
        Ok(())
    } else {
        Err(RusTorchError::InvalidRank(format!(
            "Invalid rank {} for world size {}",
            rank, world_size
        )))
    }
}
/// Checks that every tensor in `tensors` has exactly `expected_shape`.
///
/// Tensors are inspected in slice order and checking stops at the first one
/// whose shape differs. An empty slice is trivially valid.
///
/// # Errors
///
/// Returns [`RusTorchError::ShapeMismatch`] holding `expected_shape` and the
/// shape of the first non-matching tensor.
pub fn validate_tensor_shapes<T: Float + 'static>(
    tensors: &[Tensor<T>],
    expected_shape: &[usize],
) -> RusTorchResult<()> {
    tensors.iter().try_for_each(|candidate| {
        let dims = candidate.shape();
        if dims == expected_shape {
            Ok(())
        } else {
            Err(RusTorchError::ShapeMismatch {
                expected: expected_shape.to_vec(),
                actual: dims.to_vec(),
            })
        }
    })
}
/// Builds the error value used to report that `backend` cannot perform
/// `operation`.
///
/// The error is created through `RusTorchError::backend_unavailable` with a
/// message that names both the operation and the backend. The value is
/// returned, not raised; the caller decides how to propagate it.
pub fn unsupported_operation_error(operation: &str, backend: &str) -> RusTorchError {
    let message = format!(
        "Operation '{}' not supported by backend '{}'",
        operation, backend
    );
    RusTorchError::backend_unavailable(message)
}
}
/// Tuning knobs a backend can override; every method has a default body, so
/// an implementor may provide an empty `impl` and inherit all of them.
pub trait BackendOptimizations<T: Float> {
    /// Whether gradient compression should be used. Defaults to `false`.
    fn enable_gradient_compression(&self) -> bool {
        false
    }

    /// Preferred bucket size. Defaults to `25 * 1024 * 1024`
    /// (presumably bytes, i.e. 25 MiB — confirm against callers).
    fn optimal_bucket_size(&self) -> usize {
        const MIB: usize = 1024 * 1024;
        25 * MIB
    }

    /// Whether asynchronous operations are supported. Defaults to `false`.
    fn supports_async_ops(&self) -> bool {
        false
    }

    /// Preferred memory pool size. Defaults to `512 * 1024 * 1024`
    /// (presumably bytes, i.e. 512 MiB — confirm against callers).
    fn memory_pool_size(&self) -> usize {
        const MIB: usize = 1024 * 1024;
        512 * MIB
    }

    /// Whether zero-copy transfers should be used. Defaults to `true`.
    fn enable_zero_copy(&self) -> bool {
        true
    }

    /// Preferred number of streams. Defaults to `4`.
    fn optimal_stream_count(&self) -> usize {
        4
    }

    /// Whether pipeline parallelism should be used. Defaults to `true`.
    fn enable_pipeline_parallelism(&self) -> bool {
        true
    }

    /// Hook for preparing `tensor` before it is communicated.
    ///
    /// The default leaves the tensor untouched and always returns `Ok(())`.
    fn optimize_for_communication(&self, tensor: &mut Tensor<T>) -> RusTorchResult<()> {
        // The binding only marks the argument as intentionally unused.
        let _ = tensor;
        Ok(())
    }

    /// Chunk size to use for a tensor of `tensor_size`.
    ///
    /// The default is a quarter of `tensor_size` (integer division), but
    /// never less than 1024 — so for small inputs the result can be larger
    /// than `tensor_size` itself.
    fn get_optimal_chunk_size(&self, tensor_size: usize) -> usize {
        const MIN_CHUNK: usize = 1024;
        std::cmp::max(tensor_size / 4, MIN_CHUNK)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tensor::Tensor;

    // The fallback all-gather must succeed and yield one entry per rank.
    #[test]
    fn test_default_all_gather() {
        let input: Tensor<f32> = Tensor::ones(&[2, 2]);
        let gathered = CommonOps::default_all_gather(&input, 4);
        assert!(gathered.is_ok());
        let copies = gathered.unwrap();
        assert_eq!(copies.len(), 4);
    }

    // A 2x2 tensor passes validation; a tensor with a zero-length dimension
    // is rejected.
    #[test]
    fn test_validate_tensor() {
        let populated: Tensor<f32> = Tensor::ones(&[2, 2]);
        let populated_check = CommonOps::validate_tensor(&populated);
        assert!(populated_check.is_ok());

        let zero_sized: Tensor<f32> = Tensor::zeros(&[0]);
        let zero_sized_check = CommonOps::validate_tensor(&zero_sized);
        assert!(zero_sized_check.is_err());
    }

    // The generated error must be the BackendUnavailable variant and mention
    // both the operation and the backend names.
    #[test]
    fn test_unsupported_operation_error() {
        let error = CommonOps::unsupported_operation_error("test_op", "test_backend");
        if let RusTorchError::BackendUnavailable { backend: msg } = error {
            assert!(msg.contains("test_op"));
            assert!(msg.contains("test_backend"));
        } else {
            panic!("Expected BackendUnavailable error");
        }
    }
}