use pyo3::exceptions::{PyNotImplementedError, PyTypeError, PyValueError};
use pyo3::types::PyBytes;
use pyo3::prelude::*;
use planetarium::{EncoderError, Matrix, Matrix23, Pixel, Point, Vector};
use planetarium::{
Canvas as RsCanvas, ImageFormat as RsImageFormat, SpotId as RsSpotId, SpotShape as RsSpotShape,
Transform as RsTransform, Window as RsWindow,
};
// Thin Python-facing newtypes around the `planetarium` types. Every wrapper
// except `Canvas` is `frozen` (no `&mut self` access from Python), and they use
// a PyO3 free list of 8 entries so that freed objects can be reused cheaply.
/// Python wrapper around `planetarium::SpotShape`.
#[pyclass(module = "pyplanetarium", frozen, freelist = 8)]
struct SpotShape(RsSpotShape);
/// Python wrapper around `planetarium::SpotId`; instances are returned by
/// `Canvas.add_spot`.
#[pyclass(module = "pyplanetarium", frozen, freelist = 8)]
struct SpotId(RsSpotId);
/// Python wrapper around `planetarium::Transform`.
#[pyclass(module = "pyplanetarium", frozen, freelist = 8)]
struct Transform(RsTransform);
/// Python wrapper around `planetarium::Window`.
#[pyclass(module = "pyplanetarium", frozen, freelist = 8)]
struct Window(RsWindow);
/// Python wrapper around `planetarium::ImageFormat`.
#[pyclass(module = "pyplanetarium", frozen, freelist = 8)]
struct ImageFormat(RsImageFormat);
/// Python wrapper around `planetarium::Canvas`.
// Not `frozen`: its methods mutate the wrapped canvas through `&mut self`.
#[pyclass(module = "pyplanetarium")]
struct Canvas(RsCanvas);
#[pymethods]
impl SpotShape {
    /// Create a spot shape.
    ///
    /// `src` may be omitted (giving `planetarium::SpotShape::default()`), or be
    /// a number, a pair of numbers, or a `planetarium::Matrix`. Each form is
    /// converted with `.into()`. Any other type raises `TypeError`.
    #[new]
    fn new(src: Option<&PyAny>) -> PyResult<Self> {
        let src = match src {
            Some(src) => src,
            None => return Ok(SpotShape(RsSpotShape::default())),
        };
        // The accepted initializer forms are tried in this order.
        if let Ok(k) = src.extract::<f32>() {
            Ok(SpotShape(k.into()))
        } else if let Ok(kxy) = src.extract::<(f32, f32)>() {
            Ok(SpotShape(kxy.into()))
        } else if let Ok(mat) = src.extract::<Matrix>() {
            Ok(SpotShape(mat.into()))
        } else {
            // Propagate a failure to read the type name as a Python error.
            // `.unwrap()` here would panic and surface as a PanicException.
            let type_name = src.get_type().name()?;
            Err(PyTypeError::new_err(format!(
                "Unexpected initializer type: '{type_name}'"
            )))
        }
    }
    /// Return a new shape from `planetarium::SpotShape::scale(k)`.
    fn scale(&self, k: f32) -> SpotShape {
        SpotShape(self.0.scale(k))
    }
    /// Return a new shape from `planetarium::SpotShape::stretch(kx, ky)`.
    fn stretch(&self, kx: f32, ky: f32) -> SpotShape {
        SpotShape(self.0.stretch(kx, ky))
    }
    /// Return a new shape from `planetarium::SpotShape::rotate(phi)`.
    fn rotate(&self, phi: f32) -> SpotShape {
        SpotShape(self.0.rotate(phi))
    }
    /// `str()` uses the wrapped value's `Display` output.
    fn __str__(&self) -> String {
        self.0.to_string()
    }
    /// `repr()` uses the wrapped value's `Debug` output.
    fn __repr__(&self) -> String {
        format!("{:?}", self.0)
    }
}
#[pymethods]
impl SpotId {
    /// Representation showing the underlying spot index.
    fn __repr__(&self) -> String {
        format!("SpotId({index})", index = self.0)
    }
    /// Hash by the underlying spot index.
    // No `__eq__`/`__richcmp__` is defined here, so `==` falls back to object
    // identity. Hashing by value is still consistent with that.
    fn __hash__(&self) -> usize {
        let index: usize = self.0;
        index
    }
}
#[pymethods]
impl Transform {
    /// Create a transform.
    ///
    /// `src` may be omitted (giving `planetarium::Transform::default()`), or be
    /// a number, a `planetarium::Vector`, a `planetarium::Matrix`, or a
    /// `planetarium::Matrix23`. Each form is converted with `.into()`. Any other
    /// type raises `TypeError`.
    #[new]
    fn new(src: Option<&PyAny>) -> PyResult<Self> {
        let src = match src {
            Some(src) => src,
            None => return Ok(Transform(RsTransform::default())),
        };
        // The accepted initializer forms are tried in this order.
        if let Ok(k) = src.extract::<f32>() {
            Ok(Transform(k.into()))
        } else if let Ok(shift) = src.extract::<Vector>() {
            Ok(Transform(shift.into()))
        } else if let Ok(mat) = src.extract::<Matrix>() {
            Ok(Transform(mat.into()))
        } else if let Ok(mat) = src.extract::<Matrix23>() {
            Ok(Transform(mat.into()))
        } else {
            // Propagate a failure to read the type name as a Python error.
            // `.unwrap()` here would panic and surface as a PanicException.
            let type_name = src.get_type().name()?;
            Err(PyTypeError::new_err(format!(
                "Unexpected initializer type: '{type_name}'"
            )))
        }
    }
    /// Return a new transform from `planetarium::Transform::translate(shift)`.
    fn translate(&self, shift: Vector) -> Transform {
        Transform(self.0.translate(shift))
    }
    /// Return a new transform from `planetarium::Transform::scale(k)`.
    fn scale(&self, k: f32) -> Transform {
        Transform(self.0.scale(k))
    }
    /// Return a new transform from `planetarium::Transform::stretch(kx, ky)`.
    fn stretch(&self, kx: f32, ky: f32) -> Transform {
        Transform(self.0.stretch(kx, ky))
    }
    /// Return a new transform from `planetarium::Transform::rotate(phi)`.
    fn rotate(&self, phi: f32) -> Transform {
        Transform(self.0.rotate(phi))
    }
    /// Return a new transform from `planetarium::Transform::compose` with `t`.
    fn compose(&self, t: &Transform) -> Transform {
        Transform(self.0.compose(t.0))
    }
    /// `str()` uses the wrapped value's `Display` output.
    fn __str__(&self) -> String {
        self.0.to_string()
    }
    /// `repr()` uses the wrapped value's `Debug` output.
    fn __repr__(&self) -> String {
        format!("{:?}", self.0)
    }
}
#[pymethods]
impl Window {
#[new]
fn new_(src: ((u32, u32), (u32, u32))) -> Self {
Window(src.into())
}
#[staticmethod]
fn new(width: u32, height: u32) -> Self {
Window(RsWindow::new(width, height))
}
fn at(&self, x: u32, y: u32) -> Window {
Window(self.0.at(x, y))
}
fn __str__(&self) -> String {
self.0.to_string()
}
fn __repr__(&self) -> String {
format!("{:?}", self.0)
}
}
#[allow(non_upper_case_globals)]
#[pymethods]
impl ImageFormat {
#[classattr]
const RawGamma8Bpp: ImageFormat = ImageFormat(RsImageFormat::RawGamma8Bpp);
#[classattr]
const RawLinear10BppLE: ImageFormat = ImageFormat(RsImageFormat::RawLinear10BppLE);
#[classattr]
const RawLinear12BppLE: ImageFormat = ImageFormat(RsImageFormat::RawLinear12BppLE);
#[classattr]
const PngGamma8Bpp: ImageFormat = ImageFormat(RsImageFormat::PngGamma8Bpp);
#[classattr]
const PngLinear16Bpp: ImageFormat = ImageFormat(RsImageFormat::PngLinear16Bpp);
fn __repr__(&self) -> String {
format!("{:?}", self.0)
}
fn __hash__(&self) -> usize {
self.0 as usize
}
}
/// Convert a `planetarium::EncoderError` into a Python exception.
///
/// A broken window and bad subsampling factors become `ValueError`. Every
/// other encoder error becomes `NotImplementedError` carrying the error's
/// `Display` text.
fn my_to_pyerr(err: EncoderError) -> PyErr {
    let value_error_message = match err {
        EncoderError::BrokenWindow => "window is out of bounds",
        EncoderError::InvalidSubsamplingRate => "bad subsampling factors",
        _ => return PyNotImplementedError::new_err(err.to_string()),
    };
    PyValueError::new_err(value_error_message)
}
#[pymethods]
impl Canvas {
#[classattr]
const PIXEL_MAX: Pixel = Pixel::MAX;
#[staticmethod]
fn new(width: u32, height: u32) -> Self {
Canvas(RsCanvas::new(width, height))
}
fn add_spot(&mut self, position: Point, shape: &SpotShape, intensity: f32) -> SpotId {
let id = self.0.add_spot(position, shape.0, intensity);
SpotId(id)
}
fn spot_position(&self, spot: &SpotId) -> Option<Point> {
self.0.spot_position(spot.0)
}
fn spot_intensity(&self, spot: &SpotId) -> Option<f32> {
self.0.spot_intensity(spot.0)
}
fn set_spot_offset(&mut self, spot: &SpotId, offset: Vector) {
self.0.set_spot_offset(spot.0, offset);
}
fn set_spot_illumination(&mut self, spot: &SpotId, illumination: f32) {
self.0.set_spot_illumination(spot.0, illumination);
}
fn clear(&mut self) {
self.0.clear();
}
fn draw(&mut self) {
self.0.draw();
}
fn dimensions(&self) -> (u32, u32) {
self.0.dimensions()
}
fn set_background(&mut self, level: Pixel) {
self.0.set_background(level);
}
fn set_view_transform(&mut self, transform: &Transform) {
self.0.set_view_transform(transform.0);
}
fn set_brightness(&mut self, brightness: f32) {
self.0.set_brightness(brightness);
}
fn export_image(&self, format: &ImageFormat, py: Python) -> PyResult<Py<PyBytes>> {
match self.0.export_image(format.0) {
Ok(b) => Ok(PyBytes::new(py, b.as_slice()).into()),
Err(e) => Err(my_to_pyerr(e)),
}
}
fn export_window_image(
&self,
window: &Window,
format: &ImageFormat,
py: Python,
) -> PyResult<Py<PyBytes>> {
match self.0.export_window_image(window.0, format.0) {
Ok(b) => Ok(PyBytes::new(py, b.as_slice()).into()),
Err(e) => Err(my_to_pyerr(e)),
}
}
fn export_subsampled_image(
&self,
factors: (u32, u32),
format: &ImageFormat,
py: Python,
) -> PyResult<Py<PyBytes>> {
match self.0.export_subsampled_image(factors, format.0) {
Ok(b) => Ok(PyBytes::new(py, b.as_slice()).into()),
Err(e) => Err(my_to_pyerr(e)),
}
}
fn __repr__(&self) -> String {
let (w, h) = self.0.dimensions();
format!("Canvas({w}, {h})")
}
}
/// Python module entry point: sets package metadata and registers the
/// wrapper classes.
#[pymodule]
fn pyplanetarium(_py: Python, m: &PyModule) -> PyResult<()> {
    // Metadata is taken from Cargo at compile time.
    for (attr, value) in [
        ("__version__", env!("CARGO_PKG_VERSION")),
        ("__author__", env!("CARGO_PKG_AUTHORS")),
    ] {
        m.setattr(attr, value)?;
    }
    m.add_class::<SpotShape>()?;
    m.add_class::<SpotId>()?;
    m.add_class::<Transform>()?;
    m.add_class::<Window>()?;
    m.add_class::<ImageFormat>()?;
    m.add_class::<Canvas>()?;
    Ok(())
}