use crate::ir::{Envelope, EnvelopeKind};
use crate::runtime_adapter::{AdapterError, AdapterResult};
use ndarray::{Array, ArrayD, IxDyn};
/// Converts an [`Envelope`] into the named input tensor map expected by a model.
///
/// Only the first entry of `input_names` and `input_shapes` is used: the envelope payload
/// becomes a single tensor stored under `input_names[0]`. If the envelope carries a
/// `tensor_shape` metadata entry made of comma-separated integers, that shape overrides
/// `input_shapes[0]`; an entry that fails to parse is ignored.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` when `input_shapes` or `input_names` is empty, and
/// propagates any error from the audio/text/embedding conversion.
pub fn envelope_to_tensors(
    envelope: &Envelope,
    input_shapes: &[Vec<i64>],
    input_names: &[String],
) -> AdapterResult<std::collections::HashMap<String, ArrayD<f32>>> {
    let (Some(default_shape), Some(input_name)) = (input_shapes.first(), input_names.first())
    else {
        return Err(AdapterError::InvalidInput(
            "No input shapes or names provided".to_string(),
        ));
    };

    // A malformed metadata shape silently falls back to the caller-supplied shape.
    let metadata_shape = envelope.metadata.get("tensor_shape").and_then(|raw| {
        raw.split(',')
            .map(str::parse::<i64>)
            .collect::<Result<Vec<i64>, _>>()
            .ok()
    });
    let target_shape: &[i64] = metadata_shape
        .as_deref()
        .unwrap_or(default_shape.as_slice());

    let tensor = match &envelope.kind {
        EnvelopeKind::Audio(bytes) => audio_to_tensor(bytes, target_shape)?,
        EnvelopeKind::Text(text) => text_to_tensor(text, target_shape)?,
        EnvelopeKind::Embedding(values) => embedding_to_tensor(values, target_shape)?,
    };

    Ok(std::collections::HashMap::from([(input_name.clone(), tensor)]))
}
/// Packs one model output tensor into an embedding [`Envelope`].
///
/// The tensor named `output_names[0]` (or `"output"` when `output_names` is empty) is
/// flattened in logical row-major order into `EnvelopeKind::Embedding`. Its shape is recorded
/// as comma-separated dimensions under the `tensor_shape` metadata key, which is the same key
/// [`envelope_to_tensors`] reads back.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` if `outputs` is empty or the selected tensor is absent.
pub fn tensors_to_envelope(
    outputs: &std::collections::HashMap<String, ArrayD<f32>>,
    output_names: &[String],
) -> AdapterResult<Envelope> {
    if outputs.is_empty() {
        return Err(AdapterError::InvalidInput(
            "No output tensors provided".to_string(),
        ));
    }
    let output_name = output_names.first().map(|s| s.as_str()).unwrap_or("output");
    let output = outputs
        .get(output_name)
        .ok_or_else(|| AdapterError::InvalidInput(format!("Output '{}' not found", output_name)))?;
    // `iter` visits elements in logical (row-major) order whatever the memory layout, so
    // transposed or otherwise non-standard-layout outputs flatten consistently with the
    // recorded shape instead of being rejected as non-contiguous.
    let data: Vec<f32> = output.iter().copied().collect();
    let shape_str = output
        .shape()
        .iter()
        .map(|d| d.to_string())
        .collect::<Vec<_>>()
        .join(",");
    let mut metadata = std::collections::HashMap::new();
    metadata.insert("tensor_shape".to_string(), shape_str);
    Ok(Envelope::with_metadata(
        EnvelopeKind::Embedding(data),
        metadata,
    ))
}
/// Decodes `audio_data` and lays the resulting mono samples out as a model input tensor.
///
/// `target_shape` may have rank 1, 2 or 3:
/// - `[samples]` produces a `[1, samples]` tensor,
/// - `[batch, samples]` produces `[batch, samples]`,
/// - `[batch, channels, samples]` produces `[batch, channels, samples]`.
///
/// Any negative dimension is dynamic. On the sample axis it resolves to the number of
/// decoded samples, and on every other axis it resolves to 1. The decoded samples are then
/// zero-padded or truncated to exactly fill the resolved shape. They are laid out
/// contiguously, not replicated across the batch or channel axes.
///
/// # Errors
/// Propagates errors from `decode_audio_to_samples`. Returns `AdapterError::InvalidInput` if
/// the decoded audio is empty, the rank is unsupported, or the resolved element count
/// overflows `usize`. Returns `AdapterError::RuntimeError` if the tensor cannot be built.
fn audio_to_tensor(audio_data: &[u8], target_shape: &[i64]) -> AdapterResult<ArrayD<f32>> {
    let mut samples = decode_audio_to_samples(audio_data)?;
    if samples.is_empty() {
        return Err(AdapterError::InvalidInput(
            "Audio data is empty".to_string(),
        ));
    }
    let decoded_len = samples.len();
    let final_shape: Vec<usize> = match *target_shape {
        [len] => vec![1, resolve_audio_dim(len, decoded_len)],
        [batch, len] => vec![
            resolve_audio_dim(batch, 1),
            resolve_audio_dim(len, decoded_len),
        ],
        [batch, channels, len] => vec![
            resolve_audio_dim(batch, 1),
            resolve_audio_dim(channels, 1),
            resolve_audio_dim(len, decoded_len),
        ],
        _ => {
            return Err(AdapterError::InvalidInput(format!(
                "Unsupported target shape dimensions: {:?}",
                target_shape
            )))
        }
    };
    // Checked so that absurd static dims produce an error rather than an overflow panic.
    let expected_size = final_shape
        .iter()
        .try_fold(1usize, |acc, &dim| acc.checked_mul(dim))
        .ok_or_else(|| {
            AdapterError::InvalidInput(format!(
                "Target audio shape {:?} has too many elements",
                final_shape
            ))
        })?;
    // `resize` both zero-pads short audio and truncates long audio.
    samples.resize(expected_size, 0.0);
    Array::from_shape_vec(IxDyn(&final_shape), samples)
        .map_err(|e| AdapterError::RuntimeError(format!("Failed to create audio tensor: {}", e)))
}

/// Resolves one audio target dimension. A negative value is dynamic and becomes `dynamic`;
/// a non-negative value is used as given.
fn resolve_audio_dim(dim: i64, dynamic: usize) -> usize {
    // Every negative value counts as dynamic, matching the text and embedding converters.
    // Casting e.g. `-2` to `usize` would otherwise produce a ~2^64 dimension.
    if dim < 0 {
        dynamic
    } else {
        dim as usize
    }
}
/// Decodes audio bytes into mono `f32` samples.
///
/// The bytes are first parsed as a WAV file with `hound`:
/// - integer PCM is divided by `2^(bits_per_sample - 1)`, and float samples are used as stored;
/// - reading stops at the first sample that fails to decode (e.g. a truncated data chunk),
///   keeping the samples read so far;
/// - multi-channel audio is downmixed by averaging each interleaved frame;
/// - audio not already at 16 kHz is resampled to 16 kHz by nearest-neighbour (floor) index
///   selection.
///
/// If the bytes are not a parseable WAV file, they are treated as headerless 16-bit
/// little-endian PCM and returned without downmixing or resampling.
/// NOTE(review): this assumes such raw input is already 16 kHz mono; confirm with producers.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` if a WAV header declares a 0 Hz sample rate (and
/// resampling would be required), or if the raw-PCM fallback receives an odd number of bytes.
fn decode_audio_to_samples(audio_data: &[u8]) -> AdapterResult<Vec<f32>> {
    use std::io::Cursor;
    const TARGET_SAMPLE_RATE: u32 = 16000;
    let cursor = Cursor::new(audio_data);
    match hound::WavReader::new(cursor) {
        Ok(mut reader) => {
            let spec = reader.spec();
            let source_sample_rate = spec.sample_rate;
            let source_channels = spec.channels as usize;
            // `map_while` stops at the first read error instead of skipping it. Skipping a
            // sample would shift the channel interleaving of every later sample.
            let samples: Vec<f32> = match spec.sample_format {
                hound::SampleFormat::Float => {
                    reader.samples::<f32>().map_while(Result::ok).collect()
                }
                hound::SampleFormat::Int => {
                    // Full-scale value 2^(bits-1), computed in floating point. An `i32` shift
                    // `1 << 31` wraps to i32::MIN for 32-bit PCM and would flip every sign.
                    let full_scale = 2.0_f32.powi(i32::from(spec.bits_per_sample) - 1);
                    reader
                        .samples::<i32>()
                        .map_while(Result::ok)
                        .map(|s| s as f32 / full_scale)
                        .collect()
                }
            };
            let mono_samples: Vec<f32> = if source_channels > 1 {
                samples
                    .chunks(source_channels)
                    .map(|chunk| chunk.iter().sum::<f32>() / source_channels as f32)
                    .collect()
            } else {
                samples
            };
            if source_sample_rate == TARGET_SAMPLE_RATE {
                return Ok(mono_samples);
            }
            if source_sample_rate == 0 {
                return Err(AdapterError::InvalidInput(
                    "WAV sample rate must be non-zero".to_string(),
                ));
            }
            // Nearest-neighbour (floor) resampling in exact integer arithmetic. f32 index
            // math loses precision once sample counts exceed 2^24 and can mis-size the output.
            let source_rate = u64::from(source_sample_rate);
            let target_rate = u64::from(TARGET_SAMPLE_RATE);
            let target_len = (mono_samples.len() as u64 * target_rate / source_rate) as usize;
            let resampled: Vec<f32> = (0..target_len)
                .map(|i| {
                    let source_idx = (i as u64 * source_rate / target_rate) as usize;
                    mono_samples.get(source_idx).copied().unwrap_or(0.0)
                })
                .collect();
            Ok(resampled)
        }
        Err(_) => {
            // Not a WAV container: fall back to headerless 16-bit little-endian PCM.
            if !audio_data.len().is_multiple_of(2) {
                return Err(AdapterError::InvalidInput(format!(
                    "Audio data length ({}) must be even for 16-bit PCM",
                    audio_data.len()
                )));
            }
            let samples: Vec<f32> = audio_data
                .chunks_exact(2)
                .map(|chunk| {
                    let sample = i16::from_le_bytes([chunk[0], chunk[1]]);
                    sample as f32 / 32768.0
                })
                .collect();
            Ok(samples)
        }
    }
}
/// Turns `text` into a float tensor of per-word token ids shaped by `target_shape`.
///
/// NOTE(review): the id of each whitespace-separated word is simply its position
/// (0, 1, 2, ...). The word content is ignored, so this looks like a placeholder for a real
/// tokenizer.
///
/// If any dimension of `target_shape` is negative, the shape is built from the word count
/// `n`. A rank-2 target gives `[batch, n]`, where a non-positive batch falls back to 1; any
/// other rank gives `[n]`. No padding is applied in that case. Otherwise the ids are
/// zero-padded or truncated to the product of `target_shape`.
///
/// # Errors
/// Returns `AdapterError::RuntimeError` when the ids cannot be arranged into the shape
/// (e.g. a rank-2 dynamic target with a batch greater than 1 and non-empty text).
fn text_to_tensor(text: &str, target_shape: &[i64]) -> AdapterResult<ArrayD<f32>> {
    let build = |shape: Vec<usize>, ids: Vec<f32>| -> AdapterResult<ArrayD<f32>> {
        Array::from_shape_vec(IxDyn(&shape), ids)
            .map_err(|e| AdapterError::RuntimeError(format!("Failed to create text tensor: {}", e)))
    };

    let word_count = text.split_whitespace().count();
    let mut ids: Vec<f32> = (0..word_count)
        .map(|position| position as i64 as f32)
        .collect();

    if !target_shape.iter().all(|&dim| dim >= 0) {
        let shape = match *target_shape {
            [batch, _] => vec![if batch > 0 { batch as usize } else { 1 }, word_count],
            _ => vec![word_count],
        };
        return build(shape, ids);
    }

    let element_count: i64 = target_shape.iter().product();
    // `resize` zero-pads when there are too few ids and truncates when there are too many.
    ids.resize(element_count as usize, 0.0);
    build(target_shape.iter().map(|&dim| dim as usize).collect(), ids)
}
/// Wraps a flat embedding vector in a tensor shaped by `target_shape`.
///
/// If any dimension of `target_shape` is negative, the shape is inferred from the embedding
/// length `len`:
/// - for rank 2, non-positive dims are unknown and are inferred like numpy's `reshape`.
///   `[B, -1]` yields `[B, len / B]`, `[-1, F]` yields `[len / F, F]`, and `[-1, -1]` yields
///   `[1, len]`;
/// - any other rank (e.g. `[-1]`) yields `[len]`.
///
/// With a fully static shape, the embedding length must equal the product of the dims.
///
/// # Errors
/// Returns `AdapterError::InvalidInput` when a static shape's element count differs from the
/// embedding length. Returns `AdapterError::RuntimeError` when an inferred shape does not tile
/// the embedding exactly (e.g. `[-1, 3]` with 4 values).
fn embedding_to_tensor(embedding: &[f32], target_shape: &[i64]) -> AdapterResult<ArrayD<f32>> {
    let actual_size = embedding.len();
    let has_dynamic = target_shape.iter().any(|&d| d < 0);
    if has_dynamic {
        let shape: Vec<usize> = match *target_shape {
            [batch, features] => {
                // On this path only strictly positive dims count as known.
                let known = |dim: i64| (dim > 0).then_some(dim as usize);
                match (known(batch), known(features)) {
                    (Some(batch), Some(features)) => vec![batch, features],
                    (Some(batch), None) => vec![batch, actual_size / batch],
                    // Dynamic batch with a fixed width (the usual `[-1, F]` model input):
                    // infer the batch so stacked vectors are accepted, not only a single one.
                    (None, Some(features)) => vec![actual_size / features, features],
                    (None, None) => vec![1, actual_size],
                }
            }
            _ => vec![actual_size],
        };
        return Array::from_shape_vec(IxDyn(&shape), embedding.to_vec()).map_err(|e| {
            AdapterError::RuntimeError(format!("Failed to create embedding tensor: {}", e))
        });
    }
    let expected_size: i64 = target_shape.iter().product();
    if actual_size as i64 != expected_size {
        return Err(AdapterError::InvalidInput(format!(
            "Embedding size mismatch: expected {}, got {}",
            expected_size, actual_size
        )));
    }
    let shape: Vec<usize> = target_shape.iter().map(|&s| s as usize).collect();
    Array::from_shape_vec(IxDyn(&shape), embedding.to_vec()).map_err(|e| {
        AdapterError::RuntimeError(format!("Failed to create embedding tensor: {}", e))
    })
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Four bytes of headerless 16-bit little-endian PCM: +0.5 then -0.5 of full scale.
    const PCM_HALF_UP_DOWN: [u8; 4] = [0x00, 0x40, 0x00, 0xC0];

    /// Collects a tensor's elements in logical order for easy comparison.
    fn values(tensor: &ArrayD<f32>) -> Vec<f32> {
        tensor.iter().copied().collect()
    }

    #[test]
    fn test_audio_to_tensor() {
        // 32000 zero bytes are not a WAV file, so they decode as 16000 silent PCM samples.
        let audio_data = vec![0u8; 32000];
        let target_shape = vec![1, 1, 16000];
        let tensor = audio_to_tensor(&audio_data, &target_shape).expect("audio should convert");
        assert_eq!(tensor.shape(), &[1, 1, 16000]);
        assert!(tensor.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_audio_to_tensor_pads_and_truncates_raw_pcm() {
        let padded = audio_to_tensor(&PCM_HALF_UP_DOWN, &[-1, 4]).expect("padding should work");
        assert_eq!(padded.shape(), &[1, 4]);
        assert_eq!(values(&padded), vec![0.5, -0.5, 0.0, 0.0]);

        let truncated = audio_to_tensor(&PCM_HALF_UP_DOWN, &[1]).expect("truncation should work");
        assert_eq!(truncated.shape(), &[1, 1]);
        assert_eq!(values(&truncated), vec![0.5]);
    }

    #[test]
    fn test_audio_to_tensor_rejects_bad_input() {
        // Empty audio, odd-length raw PCM, and an unsupported rank all fail.
        assert!(audio_to_tensor(&[], &[1, 4]).is_err());
        assert!(audio_to_tensor(&[0u8; 3], &[1, 3]).is_err());
        assert!(audio_to_tensor(&PCM_HALF_UP_DOWN, &[1, 1, 1, 4]).is_err());
    }

    #[test]
    fn test_text_to_tensor() {
        let text = "hello world test";
        let target_shape = vec![1, 512];
        let tensor = text_to_tensor(text, &target_shape).expect("text should convert");
        assert_eq!(tensor.shape(), &[1, 512]);
        let ids = values(&tensor);
        assert_eq!(&ids[..3], &[0.0, 1.0, 2.0]);
        assert!(ids[3..].iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_text_to_tensor_dynamic_shapes() {
        let flat = text_to_tensor("a b", &[-1]).expect("flat dynamic shape");
        assert_eq!(flat.shape(), &[2]);

        let batched = text_to_tensor("a b", &[1, -1]).expect("batched dynamic shape");
        assert_eq!(batched.shape(), &[1, 2]);
        assert_eq!(values(&batched), vec![0.0, 1.0]);
    }

    #[test]
    fn test_embedding_to_tensor() {
        let embedding = vec![0.1, 0.2, 0.3, 0.4];
        let target_shape = vec![1, 4];
        let tensor = embedding_to_tensor(&embedding, &target_shape).expect("static shape");
        assert_eq!(tensor.shape(), &[1, 4]);
        assert_eq!(values(&tensor), embedding);

        assert!(embedding_to_tensor(&embedding, &[1, 3]).is_err());
        assert_eq!(embedding_to_tensor(&embedding, &[-1]).unwrap().shape(), &[4]);
        assert_eq!(embedding_to_tensor(&embedding, &[2, -1]).unwrap().shape(), &[2, 2]);
    }

    #[test]
    fn test_envelope_to_tensors_audio() {
        let envelope = Envelope::new(EnvelopeKind::Audio(vec![0u8; 32000]));
        let input_shapes = vec![vec![1, 1, 16000]];
        let input_names = vec!["audio_input".to_string()];
        let tensors = envelope_to_tensors(&envelope, &input_shapes, &input_names)
            .expect("audio envelope should convert");
        assert_eq!(tensors.len(), 1);
        assert_eq!(tensors["audio_input"].shape(), &[1, 1, 16000]);
    }

    #[test]
    fn test_envelope_to_tensors_text() {
        let envelope = Envelope::new(EnvelopeKind::Text("hello world".to_string()));
        let input_shapes = vec![vec![1, 512]];
        let input_names = vec!["text_input".to_string()];
        let tensors = envelope_to_tensors(&envelope, &input_shapes, &input_names)
            .expect("text envelope should convert");
        assert_eq!(tensors.len(), 1);
        assert_eq!(tensors["text_input"].shape(), &[1, 512]);
    }

    #[test]
    fn test_envelope_to_tensors_requires_inputs() {
        let envelope = Envelope::new(EnvelopeKind::Text("hi".to_string()));
        assert!(envelope_to_tensors(&envelope, &[], &["x".to_string()]).is_err());
        assert!(envelope_to_tensors(&envelope, &[vec![1, 4]], &[]).is_err());
    }

    #[test]
    fn test_tensors_to_envelope_round_trip() {
        let tensor = Array::from_shape_vec(IxDyn(&[2, 2]), vec![1.0_f32, 2.0, 3.0, 4.0])
            .expect("valid 2x2 tensor");
        let mut outputs = std::collections::HashMap::new();
        outputs.insert("output".to_string(), tensor);

        let envelope = tensors_to_envelope(&outputs, &[]).expect("default output name");
        match &envelope.kind {
            EnvelopeKind::Embedding(data) => assert_eq!(data, &vec![1.0, 2.0, 3.0, 4.0]),
            _ => panic!("expected an embedding envelope"),
        }

        // The recorded `tensor_shape` metadata overrides the dynamic input shape.
        let inputs = envelope_to_tensors(&envelope, &[vec![-1]], &["x".to_string()])
            .expect("round trip should convert");
        assert_eq!(inputs["x"].shape(), &[2, 2]);
        assert_eq!(values(&inputs["x"]), vec![1.0, 2.0, 3.0, 4.0]);

        assert!(tensors_to_envelope(&outputs, &["logits".to_string()]).is_err());
        assert!(tensors_to_envelope(&std::collections::HashMap::new(), &[]).is_err());
    }
}