1#![allow(clippy::useless_conversion)]
5use std::sync::OnceLock;
23
24use bytes::Bytes;
25use pyo3::create_exception;
26use pyo3::exceptions::{PyIOError, PyRuntimeError, PyValueError};
27use pyo3::prelude::*;
28use pyo3::types::PyBytes;
29use s4_codec_rs::{cpu_gzip, cpu_zstd, ChunkManifest, Codec, CodecError, CodecKind};
30use tokio::runtime::{Builder, Runtime};
31
32create_exception!(s4_codec, S4Error, PyValueError);
53create_exception!(s4_codec, S4CrcMismatchError, S4Error);
54create_exception!(s4_codec, S4SizeMismatchError, S4Error);
55create_exception!(s4_codec, S4CodecMismatchError, S4Error);
56create_exception!(s4_codec, S4UnregisteredCodecError, S4Error);
57create_exception!(s4_codec, S4ManifestSizeExceedsLimitError, S4Error);
58create_exception!(s4_codec, S4ManifestSizeMismatchError, S4Error);
59create_exception!(s4_codec, S4BackendError, PyRuntimeError);
60create_exception!(s4_codec, S4IoError, PyIOError);
61
62fn runtime() -> &'static Runtime {
63 static RT: OnceLock<Runtime> = OnceLock::new();
64 RT.get_or_init(|| {
65 Builder::new_multi_thread()
66 .enable_all()
67 .thread_name("s4-codec-py")
68 .build()
69 .expect("failed to start tokio runtime for s4_codec python binding")
70 })
71}
72
73fn codec_err_to_py(e: CodecError) -> PyErr {
74 use s4_codec_rs::CodecError::*;
75 match e {
76 SizeMismatch { expected, got } => {
77 S4SizeMismatchError::new_err(format!("size mismatch: expected {expected}, got {got}"))
78 }
79 CrcMismatch { expected, got } => S4CrcMismatchError::new_err(format!(
80 "crc32c mismatch: expected {expected:#010x}, got {got:#010x}"
81 )),
82 CodecMismatch { expected, got } => S4CodecMismatchError::new_err(format!(
83 "codec mismatch: expected {expected:?}, got {got:?}"
84 )),
85 UnregisteredCodec(k) => {
86 S4UnregisteredCodecError::new_err(format!("codec {k:?} not registered"))
87 }
88 ManifestSizeExceedsLimit { requested, limit } => S4ManifestSizeExceedsLimitError::new_err(
89 format!("manifest claims {requested} bytes but limit is {limit}"),
90 ),
91 ManifestSizeMismatch { manifest, actual } => S4ManifestSizeMismatchError::new_err(format!(
92 "manifest claims {manifest} bytes but body is {actual}"
93 )),
94 Backend(msg) => S4BackendError::new_err(format!("backend: {msg}")),
95 Io(e) => S4IoError::new_err(format!("io: {e}")),
96 TruncatedStream { expected, got } => S4Error::new_err(format!(
97 "stream truncated: expected {expected} input bytes, got {got}"
98 )),
99 Join(e) => S4BackendError::new_err(format!("backend (worker join): {e}")),
103 }
104}
105
106fn manifest_from_parts(
107 kind: CodecKind,
108 payload_len: u64,
109 original_size: u64,
110 crc32c: u32,
111) -> ChunkManifest {
112 ChunkManifest {
113 codec: kind,
114 original_size,
115 compressed_size: payload_len,
116 crc32c,
117 }
118}
119
120fn block_on<F, T>(py: Python<'_>, fut: F) -> T
124where
125 F: std::future::Future<Output = T> + Send,
126 T: Send,
127{
128 py.allow_threads(|| runtime().block_on(fut))
129}
130
131#[pyclass(name = "CpuZstd", module = "s4_codec")]
134struct PyCpuZstd {
135 inner: cpu_zstd::CpuZstd,
136}
137
138#[pymethods]
139impl PyCpuZstd {
140 #[new]
141 #[pyo3(signature = (level = 3))]
142 fn new(level: i32) -> Self {
143 Self {
144 inner: cpu_zstd::CpuZstd::new(level),
145 }
146 }
147
148 fn compress<'py>(
152 &self,
153 py: Python<'py>,
154 data: &Bound<'py, PyBytes>,
155 ) -> PyResult<(Bound<'py, PyBytes>, u64, u32)> {
156 let input = Bytes::copy_from_slice(data.as_bytes());
157 let codec = self.inner.clone();
158 let (out, manifest) =
159 block_on(py, async move { codec.compress(input).await }).map_err(codec_err_to_py)?;
160 Ok((
161 PyBytes::new(py, &out),
162 manifest.original_size,
163 manifest.crc32c,
164 ))
165 }
166
167 fn decompress<'py>(
170 &self,
171 py: Python<'py>,
172 data: &Bound<'py, PyBytes>,
173 original_size: u64,
174 crc32c: u32,
175 ) -> PyResult<Bound<'py, PyBytes>> {
176 let input = Bytes::copy_from_slice(data.as_bytes());
177 let manifest = manifest_from_parts(
178 CodecKind::CpuZstd,
179 input.len() as u64,
180 original_size,
181 crc32c,
182 );
183 let codec = self.inner.clone();
184 let out = block_on(py, async move { codec.decompress(input, &manifest).await })
185 .map_err(codec_err_to_py)?;
186 Ok(PyBytes::new(py, &out))
187 }
188
189 fn __repr__(&self) -> String {
190 format!("CpuZstd(level={})", cpu_zstd::CpuZstd::DEFAULT_LEVEL)
191 }
192}
193
194#[pyclass(name = "CpuGzip", module = "s4_codec")]
198struct PyCpuGzip {
199 inner: cpu_gzip::CpuGzip,
200}
201
202#[pymethods]
203impl PyCpuGzip {
204 #[new]
205 #[pyo3(signature = (level = 6))]
206 fn new(level: u32) -> Self {
207 Self {
208 inner: cpu_gzip::CpuGzip::new(level),
209 }
210 }
211
212 fn compress<'py>(
213 &self,
214 py: Python<'py>,
215 data: &Bound<'py, PyBytes>,
216 ) -> PyResult<(Bound<'py, PyBytes>, u64, u32)> {
217 let input = Bytes::copy_from_slice(data.as_bytes());
218 let codec = self.inner.clone();
219 let (out, manifest) =
220 block_on(py, async move { codec.compress(input).await }).map_err(codec_err_to_py)?;
221 Ok((
222 PyBytes::new(py, &out),
223 manifest.original_size,
224 manifest.crc32c,
225 ))
226 }
227
228 fn decompress<'py>(
229 &self,
230 py: Python<'py>,
231 data: &Bound<'py, PyBytes>,
232 original_size: u64,
233 crc32c: u32,
234 ) -> PyResult<Bound<'py, PyBytes>> {
235 let input = Bytes::copy_from_slice(data.as_bytes());
236 let manifest = manifest_from_parts(
237 CodecKind::CpuGzip,
238 input.len() as u64,
239 original_size,
240 crc32c,
241 );
242 let codec = self.inner.clone();
243 let out = block_on(py, async move { codec.decompress(input, &manifest).await })
244 .map_err(codec_err_to_py)?;
245 Ok(PyBytes::new(py, &out))
246 }
247
248 fn __repr__(&self) -> String {
249 format!("CpuGzip(level={})", cpu_gzip::CpuGzip::DEFAULT_LEVEL)
250 }
251}
252
253#[pyfunction]
256fn gpu_available() -> bool {
257 s4_codec_rs::nvcomp::is_gpu_available()
258}
259
260#[pymodule]
261fn s4_codec(py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
262 m.add_class::<PyCpuZstd>()?;
263 m.add_class::<PyCpuGzip>()?;
264 m.add_function(wrap_pyfunction!(gpu_available, m)?)?;
265 m.add("__version__", env!("CARGO_PKG_VERSION"))?;
266 m.add("S4Error", py.get_type::<S4Error>())?;
270 m.add("S4CrcMismatchError", py.get_type::<S4CrcMismatchError>())?;
271 m.add("S4SizeMismatchError", py.get_type::<S4SizeMismatchError>())?;
272 m.add(
273 "S4CodecMismatchError",
274 py.get_type::<S4CodecMismatchError>(),
275 )?;
276 m.add(
277 "S4UnregisteredCodecError",
278 py.get_type::<S4UnregisteredCodecError>(),
279 )?;
280 m.add(
281 "S4ManifestSizeExceedsLimitError",
282 py.get_type::<S4ManifestSizeExceedsLimitError>(),
283 )?;
284 m.add(
285 "S4ManifestSizeMismatchError",
286 py.get_type::<S4ManifestSizeMismatchError>(),
287 )?;
288 m.add("S4BackendError", py.get_type::<S4BackendError>())?;
289 m.add("S4IoError", py.get_type::<S4IoError>())?;
290 Ok(())
291}