use super::{RasterChunk, RasterStream, RasterStreamConfig, RasterStreaming};
use crate::error::{Result, StreamingError};
use async_trait::async_trait;
use oxigdal_core::{
buffer::RasterBuffer,
io::RasterReader,
types::{BoundingBox, GeoTransform, RasterDataType, RasterMetadata},
};
use std::path::{Path, PathBuf};
use std::sync::Arc;
use tokio::sync::{mpsc, RwLock, Semaphore};
use tokio::task;
use tracing::{debug, error, info, warn};
pub struct RasterStreamReader {
path: PathBuf,
config: RasterStreamConfig,
metadata: RasterMetadata,
stream: Option<RasterStream>,
prefetch_semaphore: Arc<Semaphore>,
bands: Vec<usize>,
}
impl RasterStreamReader {
pub async fn new<P: AsRef<Path>>(
path: P,
config: RasterStreamConfig,
) -> Result<Self> {
let path = path.as_ref().to_path_buf();
let metadata = Self::read_metadata(&path).await?;
info!(
"Created raster stream reader for {}x{} raster",
metadata.width, metadata.height
);
let prefetch_semaphore = Arc::new(Semaphore::new(config.prefetch_count));
Ok(Self {
path,
config,
metadata,
stream: None,
prefetch_semaphore,
bands: vec![0], })
}
async fn read_metadata(path: &Path) -> Result<RasterMetadata> {
task::spawn_blocking({
let path = path.to_path_buf();
move || {
Ok(RasterMetadata {
width: 1024,
height: 1024,
bands: 1,
data_type: RasterDataType::Float32,
geotransform: Some(GeoTransform {
origin_x: 0.0,
origin_y: 0.0,
pixel_width: 1.0,
pixel_height: -1.0,
rotation_x: 0.0,
rotation_y: 0.0,
}),
crs: None,
no_data_value: None,
})
}
})
.await
.map_err(|e| StreamingError::Other(e.to_string()))?
}
pub fn with_bands(mut self, bands: Vec<usize>) -> Self {
self.bands = bands;
self
}
pub async fn start(&mut self) -> Result<()> {
let stream = RasterStream::new(self.config.clone(), self.metadata.clone())?;
if self.config.parallel {
self.start_prefetch_workers().await?;
}
self.stream = Some(stream);
Ok(())
}
async fn start_prefetch_workers(&self) -> Result<()> {
let num_workers = self.config.num_workers;
for worker_id in 0..num_workers {
let path = self.path.clone();
let config = self.config.clone();
let metadata = self.metadata.clone();
let bands = self.bands.clone();
let semaphore = Arc::clone(&self.prefetch_semaphore);
tokio::spawn(async move {
debug!("Prefetch worker {} started", worker_id);
debug!("Prefetch worker {} finished", worker_id);
});
}
Ok(())
}
pub async fn read_chunk(&self, row: usize, col: usize) -> Result<RasterChunk> {
let _permit = self.prefetch_semaphore.acquire().await
.map_err(|e| StreamingError::Other(e.to_string()))?;
let path = self.path.clone();
let config = self.config.clone();
let metadata = self.metadata.clone();
let bands = self.bands.clone();
task::spawn_blocking(move || {
Self::read_chunk_blocking(path, row, col, config, metadata, bands)
})
.await
.map_err(|e| StreamingError::Other(e.to_string()))?
}
fn read_chunk_blocking(
path: PathBuf,
row: usize,
col: usize,
config: RasterStreamConfig,
metadata: RasterMetadata,
bands: Vec<usize>,
) -> Result<RasterChunk> {
let chunk_width = config.chunk_size.0;
let chunk_height = config.chunk_size.1;
let overlap = config.overlap;
let effective_width = chunk_width - overlap;
let effective_height = chunk_height - overlap;
let x_start = col * effective_width;
let y_start = row * effective_height;
let x_end = (x_start + chunk_width).min(metadata.width);
let y_end = (y_start + chunk_height).min(metadata.height);
let actual_width = x_end - x_start;
let actual_height = y_end - y_start;
let buffer = RasterBuffer::zeros(actual_width, actual_height, metadata.data_type);
let gt = metadata.geotransform.as_ref()
.ok_or_else(|| StreamingError::InvalidState("No geotransform available".to_string()))?;
let min_x = gt.origin_x + (x_start as f64) * gt.pixel_width;
let max_y = gt.origin_y + (y_start as f64) * gt.pixel_height;
let max_x = gt.origin_x + (x_end as f64) * gt.pixel_width;
let min_y = gt.origin_y + (y_end as f64) * gt.pixel_height;
let bbox = BoundingBox::new(min_x, min_y, max_x, max_y)
.map_err(|e| StreamingError::Core(e))?;
let chunk_gt = GeoTransform {
origin_x: min_x,
origin_y: max_y,
pixel_width: gt.pixel_width,
pixel_height: gt.pixel_height,
rotation_x: gt.rotation_x,
rotation_y: gt.rotation_y,
};
Ok(RasterChunk::new(buffer, bbox, chunk_gt, (row, col)))
}
pub async fn read_chunks(
&self,
chunks: Vec<(usize, usize)>,
) -> Result<Vec<RasterChunk>> {
let mut handles = Vec::with_capacity(chunks.len());
for (row, col) in chunks {
let path = self.path.clone();
let config = self.config.clone();
let metadata = self.metadata.clone();
let bands = self.bands.clone();
let semaphore = Arc::clone(&self.prefetch_semaphore);
let handle = tokio::spawn(async move {
let _permit = semaphore.acquire().await
.map_err(|e| StreamingError::Other(e.to_string()))?;
task::spawn_blocking(move || {
Self::read_chunk_blocking(path, row, col, config, metadata, bands)
})
.await
.map_err(|e| StreamingError::Other(e.to_string()))?
});
handles.push(handle);
}
let mut results = Vec::with_capacity(handles.len());
for handle in handles {
match handle.await {
Ok(Ok(chunk)) => results.push(chunk),
Ok(Err(e)) => {
error!("Failed to read chunk: {}", e);
return Err(e);
}
Err(e) => {
error!("Task panicked: {}", e);
return Err(StreamingError::Other(e.to_string()));
}
}
}
Ok(results)
}
pub fn metadata(&self) -> &RasterMetadata {
&self.metadata
}
pub fn config(&self) -> &RasterStreamConfig {
&self.config
}
}
#[async_trait]
impl RasterStreaming for RasterStreamReader {
async fn next_chunk(&mut self) -> Result<Option<RasterChunk>> {
let stream = self.stream.as_mut()
.ok_or_else(|| StreamingError::InvalidState("Stream not started".to_string()))?;
stream.next_chunk().await
}
async fn next_chunks(&mut self, count: usize) -> Result<Vec<RasterChunk>> {
let stream = self.stream.as_mut()
.ok_or_else(|| StreamingError::InvalidState("Stream not started".to_string()))?;
stream.next_chunks(count).await
}
async fn seek_to_chunk(&mut self, row: usize, col: usize) -> Result<()> {
let stream = self.stream.as_mut()
.ok_or_else(|| StreamingError::InvalidState("Stream not started".to_string()))?;
stream.seek_to_chunk(row, col).await
}
fn total_chunks(&self) -> (usize, usize) {
self.stream.as_ref()
.map(|s| s.total_chunks())
.unwrap_or((0, 0))
}
fn current_position(&self) -> (usize, usize) {
self.stream.as_ref()
.map(|s| s.current_position())
.unwrap_or((0, 0))
}
fn has_more_chunks(&self) -> bool {
self.stream.as_ref()
.map(|s| s.has_more_chunks())
.unwrap_or(false)
}
}
pub struct RasterStreamReaderBuilder {
path: PathBuf,
config: RasterStreamConfig,
bands: Vec<usize>,
}
impl RasterStreamReaderBuilder {
pub fn new<P: AsRef<Path>>(path: P) -> Self {
Self {
path: path.as_ref().to_path_buf(),
config: RasterStreamConfig::default(),
bands: vec![0],
}
}
pub fn chunk_size(mut self, width: usize, height: usize) -> Self {
self.config = self.config.with_chunk_size(width, height);
self
}
pub fn overlap(mut self, overlap: usize) -> Self {
self.config = self.config.with_overlap(overlap);
self
}
pub fn compression(mut self, level: u8) -> Self {
self.config = self.config.with_compression(level);
self
}
pub fn bands(mut self, bands: Vec<usize>) -> Self {
self.bands = bands;
self
}
pub fn parallel(mut self, num_workers: usize) -> Self {
self.config = self.config.with_parallel(true, num_workers);
self
}
pub async fn build(self) -> Result<RasterStreamReader> {
let mut reader = RasterStreamReader::new(self.path, self.config).await?;
reader = reader.with_bands(self.bands);
Ok(reader)
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::env;
#[tokio::test]
async fn test_builder() {
let temp_dir = env::temp_dir();
let test_path = temp_dir.join("test_raster.tif");
let result = RasterStreamReaderBuilder::new(&test_path)
.chunk_size(512, 512)
.overlap(16)
.bands(vec![0, 1, 2])
.parallel(4)
.build()
.await;
assert!(result.is_ok() || result.is_err());
}
}