#![allow(clippy::doc_markdown)]
use std::io::Cursor;
use bytes::Bytes;
use dlpark::SafeManagedTensorVersioned;
use dlpark::ffi::Device;
use ndarray::Array3;
use numpy::{PyArray1, PyArray3, ToPyArray};
use object_store::{ObjectStore, parse_url};
use pyo3::exceptions::{PyBufferError, PyFileNotFoundError, PyNotImplementedError, PyValueError};
use pyo3::prelude::{PyModule, PyResult, Python, pyclass, pyfunction, pymethods, pymodule};
use pyo3::types::PyModuleMethods;
use pyo3::{Bound, PyErr, wrap_pyfunction};
use pyo3_stub_gen::define_stub_info_gatherer;
use pyo3_stub_gen_derive::{gen_stub_pyclass, gen_stub_pyfunction, gen_stub_pymethods};
use url::Url;
use crate::io::geotiff::{CogReader, read_geotiff};
#[cfg(all(feature = "cuda", not(doctest)))]
use crate::python::cudacog::PyCudaCogReader;
use crate::traits::Transform;
/// Python-facing `CogReader` class: wraps a Rust `CogReader` whose input is a byte buffer
/// held entirely in memory (`Cursor<Bytes>`).
#[gen_stub_pyclass]
#[pyclass]
#[pyo3(name = "CogReader")]
struct PyCogReader {
    // Underlying reader over the complete file contents fetched by `path_to_stream`.
    inner: CogReader<Cursor<Bytes>>,
}
#[gen_stub_pymethods]
#[pymethods]
impl PyCogReader {
#[new]
fn new(path: &str) -> PyResult<Self> {
let stream: Cursor<Bytes> = path_to_stream(path)?;
let reader =
CogReader::new(stream).map_err(|err| PyValueError::new_err(err.to_string()))?;
Ok(Self { inner: reader })
}
#[gen_stub(override_return_type(type_repr="types.CapsuleType", imports=("types")))]
#[pyo3(signature = (stream=None))]
fn __dlpack__(&mut self, stream: Option<u8>) -> PyResult<SafeManagedTensorVersioned> {
if stream.is_some() {
Err(PyNotImplementedError::new_err(
"stream values other than `None` not supported.",
))
} else {
Ok(())
}?;
let tensor: SafeManagedTensorVersioned = self
.inner
.dlpack()
.map_err(|err| PyValueError::new_err(err.to_string()))?;
Ok(tensor)
}
#[staticmethod]
fn __dlpack_device__() -> (i32, i32) {
let device = Device::CPU;
(device.device_type as i32, device.device_id)
}
#[allow(clippy::type_complexity)]
fn xy_coords<'py>(
&mut self,
py: Python<'py>,
) -> PyResult<(Bound<'py, PyArray1<f64>>, Bound<'py, PyArray1<f64>>)> {
let (x_coords, y_coords) = self
.inner
.xy_coords()
.map_err(|err| PyValueError::new_err(err.to_string()))?;
Ok((x_coords.to_pyarray(py), y_coords.to_pyarray(py)))
}
}
pub(crate) fn path_to_stream(path: &str) -> PyResult<Cursor<Bytes>> {
let file_or_url = match Url::from_file_path(path) {
Ok(filepath) => filepath,
Err(()) => Url::parse(path)
.map_err(|_| PyValueError::new_err(format!("Cannot parse path: {path}")))?,
};
let (store, location) = parse_url(&file_or_url)
.map_err(|_| PyValueError::new_err(format!("Cannot parse url: {file_or_url}")))?;
let runtime = tokio::runtime::Builder::new_multi_thread()
.enable_all()
.build()?;
let stream = runtime.block_on(async {
let result = store
.get(&location)
.await
.map_err(|_| PyFileNotFoundError::new_err(format!("Cannot find file: {path}")))?;
let bytes = result.bytes().await.map_err(|_| {
PyBufferError::new_err(format!("Failed to stream data from {path} into bytes."))
})?;
Ok::<Cursor<Bytes>, PyErr>(Cursor::new(bytes))
})?;
Ok(stream)
}
/// Read the GeoTIFF at `path` (a local file path or an object-store URL) into a 3-D NumPy
/// `float32` array.
///
/// Errors raised while locating or downloading `path` propagate unchanged; failures from
/// `read_geotiff` are raised as `ValueError`.
#[gen_stub_pyfunction]
#[pyfunction]
#[pyo3(name = "read_geotiff")]
fn read_geotiff_py<'py>(path: &str, py: Python<'py>) -> PyResult<Bound<'py, PyArray3<f32>>> {
    let raster: Array3<f32> = path_to_stream(path).and_then(|stream| {
        read_geotiff(stream).map_err(|err| PyValueError::new_err(err.to_string()))
    })?;
    Ok(raster.to_pyarray(py))
}
/// Python extension module `cog3pio`.
///
/// Registers the `CogReader` class and the `read_geotiff` function. When built with the
/// `cuda` feature, it also registers `PyCudaCogReader`.
#[pymodule]
fn cog3pio(_py: Python, m: &Bound<'_, PyModule>) -> PyResult<()> {
    m.add_class::<PyCogReader>()?;
    // Keep this cfg identical to the one on the `PyCudaCogReader` import at the top of the
    // file. Otherwise a `cuda` + `doctest` build would reference a name that is not in scope.
    #[cfg(all(feature = "cuda", not(doctest)))]
    m.add_class::<PyCudaCogReader>()?;
    m.add_function(wrap_pyfunction!(read_geotiff_py, m)?)?;
    Ok(())
}
// Generates a `stub_info` function through which `pyo3_stub_gen` gathers the
// `#[gen_stub_*]` metadata declared in this crate and emits Python type stubs.
// Presumably called from a separate stub-generation binary — confirm.
define_stub_info_gatherer!(stub_info);