use std::fmt::Write;
use std::os::raw::c_int;
use std::ptr::NonNull;
use bytes::{Bytes, BytesMut};
use pyo3::buffer::PyBuffer;
use pyo3::exceptions::{PyIndexError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyDict, PySlice, PyTuple};
use pyo3::{ffi, IntoPyObjectExt};
// Python-visible `Bytes` type: a thin wrapper around a `bytes::Bytes` buffer.
// `frozen` means Python code cannot mutate the instance (methods only take `&self`), so the
// wrapped buffer never changes after construction; `subclass` allows Python subclasses and
// `weakref` enables weak references.
// NOTE: plain `//` comments are used on purpose — a `///` doc comment here would become the
// Python class `__doc__`.
#[pyclass(name = "Bytes", subclass, frozen, sequence, weakref)]
#[derive(Hash, PartialEq, PartialOrd, Eq, Ord)]
pub struct PyBytes(Bytes);
impl AsRef<Bytes> for PyBytes {
fn as_ref(&self) -> &Bytes {
&self.0
}
}
impl AsRef<[u8]> for PyBytes {
    /// Borrow the buffer contents as a plain byte slice.
    fn as_ref(&self) -> &[u8] {
        &self.0[..]
    }
}
impl PyBytes {
    /// Wrap an existing [`Bytes`] buffer without copying it.
    pub fn new(buffer: Bytes) -> Self {
        Self(buffer)
    }

    /// Consume the wrapper and return the underlying [`Bytes`].
    pub fn into_inner(self) -> Bytes {
        self.0
    }

    /// Borrow the contents as a byte slice.
    pub fn as_slice(&self) -> &[u8] {
        self.as_ref()
    }

    /// Evaluate a Python slice object against this buffer.
    ///
    /// A contiguous forward slice (`step == 1`) shares the underlying buffer (zero-copy).
    /// Any other step copies the selected bytes into a newly allocated buffer.
    ///
    /// # Errors
    /// Propagates any error raised by `PySlice::indices` (e.g. for a zero step).
    fn slice(&self, slice: &Bound<'_, PySlice>) -> PyResult<PyBytes> {
        let bytes_length = self.0.len() as isize;
        let indices = slice.indices(bytes_length)?;
        let (start, stop, step) = (indices.start, indices.stop, indices.step);

        // Number of selected elements: ceil(|stop - start| / |step|) when the range is
        // non-empty in the direction of `step`, otherwise 0.
        let new_capacity = if (step > 0 && stop > start) || (step < 0 && stop < start) {
            (((stop - start).abs() + step.abs() - 1) / step.abs()) as usize
        } else {
            0
        };
        if new_capacity == 0 {
            return Ok(PyBytes(Bytes::new()));
        }

        if step == 1 {
            // For a positive step `indices` clamps start/stop into `0..=len`, and
            // `new_capacity > 0` implies `start < stop`, so this range is always in bounds.
            // The result shares this object's buffer instead of copying it.
            return Ok(PyBytes(self.0.slice(start as usize..stop as usize)));
        }

        let mut new_buf = BytesMut::with_capacity(new_capacity);
        if step > 0 {
            new_buf.extend(
                (start..stop)
                    .step_by(step as usize)
                    .map(|i| self.0[i as usize]),
            );
        } else {
            // Walk backwards from `start` down to (but excluding) `stop`; for a negative
            // step `stop` may be -1, which makes the lower bound 0.
            new_buf.extend(
                (stop + 1..=start)
                    .rev()
                    .step_by(step.unsigned_abs())
                    .map(|i| self.0[i as usize]),
            );
        }
        Ok(PyBytes(new_buf.freeze()))
    }
}
impl From<PyBytes> for Bytes {
fn from(value: PyBytes) -> Self {
value.0
}
}
impl From<Vec<u8>> for PyBytes {
fn from(value: Vec<u8>) -> Self {
PyBytes(value.into())
}
}
impl From<Bytes> for PyBytes {
fn from(value: Bytes) -> Self {
PyBytes(value)
}
}
impl From<BytesMut> for PyBytes {
fn from(value: BytesMut) -> Self {
PyBytes(value.into())
}
}
#[pymethods]
impl PyBytes {
#[new]
#[pyo3(signature = (buf = PyBytes(Bytes::new())), text_signature = "(buf = b'')")]
fn py_new(buf: PyBytes) -> Self {
buf
}
fn __getnewargs_ex__(&self, py: Python) -> PyResult<PyObject> {
let py_bytes = self.to_bytes(py);
let args = PyTuple::new(py, vec![py_bytes])?.into_py_any(py)?;
let kwargs = PyDict::new(py);
PyTuple::new(py, [args, kwargs.into_py_any(py)?])?.into_py_any(py)
}
fn __len__(&self) -> usize {
self.0.len()
}
fn __repr__(&self) -> String {
format!("{self:?}")
}
fn __add__(&self, other: PyBytes) -> PyBytes {
let total_length = self.0.len() + other.0.len();
let mut new_buffer = BytesMut::with_capacity(total_length);
new_buffer.extend_from_slice(&self.0);
new_buffer.extend_from_slice(&other.0);
new_buffer.into()
}
fn __contains__(&self, item: PyBytes) -> bool {
self.0
.windows(item.0.len())
.any(|window| window == item.as_slice())
}
fn __eq__(&self, other: PyBytes) -> bool {
self.0.as_ref() == other.0.as_ref()
}
fn __getitem__<'py>(
&self,
py: Python<'py>,
key: BytesGetItemKey<'py>,
) -> PyResult<Bound<'py, PyAny>> {
match key {
BytesGetItemKey::Int(mut index) => {
if index < 0 {
index += self.0.len() as isize;
}
if index < 0 {
return Err(PyIndexError::new_err("Index out of range"));
}
self.0
.get(index as usize)
.ok_or(PyIndexError::new_err("Index out of range"))?
.into_bound_py_any(py)
}
BytesGetItemKey::Slice(slice) => {
let s = self.slice(&slice)?;
s.into_bound_py_any(py)
}
}
}
fn __mul__(&self, value: usize) -> PyBytes {
let mut out_buf = BytesMut::with_capacity(self.0.len() * value);
(0..value).for_each(|_| out_buf.extend_from_slice(self.0.as_ref()));
out_buf.into()
}
#[allow(unsafe_code)]
unsafe fn __getbuffer__(
slf: PyRef<Self>,
view: *mut ffi::Py_buffer,
flags: c_int,
) -> PyResult<()> {
let bytes = slf.0.as_ref();
let ret = ffi::PyBuffer_FillInfo(
view,
slf.as_ptr() as *mut _,
bytes.as_ptr() as *mut _,
bytes.len().try_into().unwrap(),
1, flags,
);
if ret == -1 {
return Err(PyErr::fetch(slf.py()));
}
Ok(())
}
#[allow(unsafe_code)]
unsafe fn __releasebuffer__(&self, _view: *mut ffi::Py_buffer) {}
#[pyo3(signature = (prefix, /))]
fn removeprefix(&self, prefix: PyBytes) -> PyBytes {
if self.0.starts_with(prefix.as_ref()) {
self.0.slice(prefix.0.len()..).into()
} else {
self.0.clone().into()
}
}
#[pyo3(signature = (suffix, /))]
fn removesuffix(&self, suffix: PyBytes) -> PyBytes {
if self.0.ends_with(suffix.as_ref()) {
self.0.slice(0..self.0.len() - suffix.0.len()).into()
} else {
self.0.clone().into()
}
}
fn isalnum(&self) -> bool {
if self.0.is_empty() {
return false;
}
for c in self.0.as_ref() {
if !c.is_ascii_alphanumeric() {
return false;
}
}
true
}
fn isalpha(&self) -> bool {
if self.0.is_empty() {
return false;
}
for c in self.0.as_ref() {
if !c.is_ascii_alphabetic() {
return false;
}
}
true
}
fn isascii(&self) -> bool {
for c in self.0.as_ref() {
if !c.is_ascii() {
return false;
}
}
true
}
fn isdigit(&self) -> bool {
if self.0.is_empty() {
return false;
}
for c in self.0.as_ref() {
if !c.is_ascii_digit() {
return false;
}
}
true
}
fn islower(&self) -> bool {
let mut has_lower = false;
for c in self.0.as_ref() {
if c.is_ascii_uppercase() {
return false;
}
if !has_lower && c.is_ascii_lowercase() {
has_lower = true;
}
}
has_lower
}
fn isspace(&self) -> bool {
if self.0.is_empty() {
return false;
}
for c in self.0.as_ref() {
if !(c.is_ascii_whitespace() || *c == b'\x0b') {
return false;
}
}
true
}
fn isupper(&self) -> bool {
let mut has_upper = false;
for c in self.0.as_ref() {
if c.is_ascii_lowercase() {
return false;
}
if !has_upper && c.is_ascii_uppercase() {
has_upper = true;
}
}
has_upper
}
fn lower(&self) -> PyBytes {
self.0.to_ascii_lowercase().into()
}
fn upper(&self) -> PyBytes {
self.0.to_ascii_uppercase().into()
}
fn to_bytes<'py>(&'py self, py: Python<'py>) -> Bound<'py, pyo3::types::PyBytes> {
pyo3::types::PyBytes::new(py, &self.0)
}
}
impl<'py> FromPyObject<'py> for PyBytes {
    /// Accept any object `PyBytesWrapper` can be extracted from and wrap its buffer without
    /// copying: the wrapper becomes the owner of the resulting `Bytes`.
    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
        ob.extract::<PyBytesWrapper>()
            .map(|wrapper| Self(Bytes::from_owner(wrapper)))
    }
}
/// Owner that keeps a Python buffer-protocol view alive while a `Bytes` references it
/// (it is handed to `Bytes::from_owner` in `FromPyObject for PyBytes`).
///
/// The `Option` lets `Drop` take the view out and decide whether to release or leak it.
#[derive(Debug)]
struct PyBytesWrapper(Option<PyBuffer<u8>>);
impl Drop for PyBytesWrapper {
    /// Release the buffer view while the interpreter is running; if it is not initialized
    /// (e.g. after finalization), leak the view instead of releasing it.
    #[allow(unsafe_code)]
    fn drop(&mut self) {
        // SAFETY: `Py_IsInitialized` only reports interpreter state and is documented as
        // callable before initialization / after finalization.
        let interpreter_alive = unsafe { ffi::Py_IsInitialized() } != 0;
        match self.0.take() {
            Some(buffer) if interpreter_alive => drop(buffer),
            // NOTE(review): presumably releasing needs a live interpreter, hence the leak.
            Some(buffer) => std::mem::forget(buffer),
            None => {}
        }
    }
}
impl AsRef<[u8]> for PyBytesWrapper {
    /// View the exported buffer as a byte slice.
    ///
    /// # Panics
    /// Panics if the buffer was already taken by `Drop`, or if a non-empty buffer reports a
    /// null data pointer.
    #[allow(unsafe_code)]
    fn as_ref(&self) -> &[u8] {
        let buffer = self.0.as_ref().expect("Buffer already disposed");
        let len = buffer.item_count();
        if len == 0 {
            // An exporter may report a null data pointer for an empty buffer.
            // `from_raw_parts` requires a non-null pointer even for length 0, so return a
            // (static) empty slice instead of panicking on the null check below.
            return &[];
        }
        let ptr = NonNull::new(buffer.buf_ptr().cast::<u8>())
            .expect("Expected buffer ptr to be non null");
        // SAFETY: `validate_buffer` (run on extraction) ensured the buffer is C-contiguous
        // with unit strides, so `len` u8 items starting at `ptr` are one contiguous
        // allocation. The view stays acquired while `self.0` is `Some`, so the memory
        // outlives the returned borrow of `self`.
        unsafe { std::slice::from_raw_parts(ptr.as_ptr(), len) }
    }
}
/// Check that a `PyBuffer<u8>` can be safely viewed as one flat byte slice.
///
/// # Errors
/// `ValueError` if the buffer is not C-contiguous or if any stride differs from 1.
fn validate_buffer(buf: &PyBuffer<u8>) -> PyResult<()> {
    if !buf.is_c_contiguous() {
        return Err(PyValueError::new_err("Buffer is not C contiguous"));
    }
    let strides = buf.strides();
    if strides.iter().all(|&stride| stride == 1) {
        Ok(())
    } else {
        Err(PyValueError::new_err(format!(
            "strides other than 1 not supported, got: {:?} ",
            strides
        )))
    }
}
impl<'py> FromPyObject<'py> for PyBytesWrapper {
    /// Acquire a `u8` buffer view from `ob` and accept it only if `validate_buffer` passes.
    fn extract_bound(ob: &Bound<'py, PyAny>) -> PyResult<Self> {
        let buffer: PyBuffer<u8> = ob.extract()?;
        validate_buffer(&buffer).map(|()| Self(Some(buffer)))
    }
}
impl std::fmt::Debug for PyBytes {
    /// Format as `Bytes(b"...")`.
    ///
    /// Printable ASCII (0x20..=0x7E) is written literally, except that `\`, `"`, newline,
    /// carriage return and tab use backslash escapes; every other byte becomes `\xNN` in
    /// lowercase hex.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Bytes(b\"")?;
        self.0.iter().try_for_each(|&byte| match byte {
            b'\\' => f.write_str(r"\\"),
            b'"' => f.write_str("\\\""),
            b'\n' => f.write_str(r"\n"),
            b'\r' => f.write_str(r"\r"),
            b'\t' => f.write_str(r"\t"),
            0x20..=0x7E => f.write_char(char::from(byte)),
            _ => write!(f, "\\x{byte:02x}"),
        })?;
        f.write_str("\")")
    }
}
/// Key accepted by `PyBytes::__getitem__`.
///
/// The derived `FromPyObject` tries the variants in declaration order: an integer index
/// first, then a slice object.
#[derive(FromPyObject)]
enum BytesGetItemKey<'py> {
    /// Single-byte index; `__getitem__` resolves negative values from the end.
    Int(isize),
    /// Python slice object, evaluated by `PyBytes::slice`.
    Slice(Bound<'py, PySlice>),
}