use std::collections::HashMap;
use std::fmt::Display;
use std::sync::Arc;

use arrow_schema::{Schema, SchemaRef};
use pyo3::exceptions::{PyIndexError, PyTypeError, PyValueError};
use pyo3::prelude::*;
use pyo3::types::{PyBytes, PyCapsule, PyDict, PyTuple, PyType};
use pyo3::{intern, IntoPyObjectExt};

use crate::error::PyArrowResult;
use crate::export::{Arro3DataType, Arro3Field, Arro3Schema, Arro3Table};
use crate::ffi::from_python::utils::import_schema_pycapsule;
use crate::ffi::to_python::nanoarrow::to_nanoarrow_schema;
use crate::ffi::to_python::to_schema_pycapsule;
use crate::input::{FieldIndexInput, MetadataInput};
use crate::{PyDataType, PyField, PyTable};
18
/// A Python-facing wrapper around an Arrow [`SchemaRef`] (a shared `Arc<Schema>`).
///
/// Registered with PyO3 under the Python name `Schema` in module `arro3.core._core`.
/// The class is `frozen` (no mutable access from Python) and allows Python subclasses.
/// Operations that "modify" a schema build and return a new schema instead.
#[derive(Debug)]
#[pyclass(module = "arro3.core._core", name = "Schema", subclass, frozen)]
pub struct PySchema(SchemaRef);
25
26impl PySchema {
27 pub fn new(schema: SchemaRef) -> Self {
29 Self(schema)
30 }
31
32 pub fn from_arrow_pycapsule(capsule: &Bound<PyCapsule>) -> PyResult<Self> {
34 let schema_ptr = import_schema_pycapsule(capsule)?;
35 let schema =
36 Schema::try_from(schema_ptr).map_err(|err| PyTypeError::new_err(err.to_string()))?;
37 Ok(Self::new(Arc::new(schema)))
38 }
39
40 pub fn into_inner(self) -> SchemaRef {
42 self.0
43 }
44
45 pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
47 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
48 arro3_mod.getattr(intern!(py, "Schema"))?.call_method1(
49 intern!(py, "from_arrow_pycapsule"),
50 PyTuple::new(py, vec![self.__arrow_c_schema__(py)?])?,
51 )
52 }
53
54 pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
56 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
57 let capsule = to_schema_pycapsule(py, self.0.as_ref())?;
58 arro3_mod.getattr(intern!(py, "Schema"))?.call_method1(
59 intern!(py, "from_arrow_pycapsule"),
60 PyTuple::new(py, vec![capsule])?,
61 )
62 }
63
64 pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
66 to_nanoarrow_schema(py, &self.__arrow_c_schema__(py)?)
67 }
68
69 pub fn to_pyarrow(self, py: Python) -> PyResult<PyObject> {
73 let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
74 let pyarrow_obj = pyarrow_mod
75 .getattr(intern!(py, "schema"))?
76 .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)?;
77 pyarrow_obj.into_py_any(py)
78 }
79}
80
81impl From<PySchema> for SchemaRef {
82 fn from(value: PySchema) -> Self {
83 value.0
84 }
85}
86
87impl From<&PySchema> for SchemaRef {
88 fn from(value: &PySchema) -> Self {
89 value.0.as_ref().clone().into()
90 }
91}
92
93impl From<SchemaRef> for PySchema {
94 fn from(value: SchemaRef) -> Self {
95 Self(value)
96 }
97}
98
99impl AsRef<Schema> for PySchema {
100 fn as_ref(&self) -> &Schema {
101 &self.0
102 }
103}
104
105impl Display for PySchema {
106 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
107 writeln!(f, "arro3.core.Schema")?;
108 writeln!(f, "------------")?;
109 display_schema(&self.0, f)
110 }
111}
112
113pub(crate) fn display_schema(schema: &Schema, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
114 schema.fields().iter().try_for_each(|field| {
115 f.write_str(field.name().as_str())?;
116 write!(f, ": ")?;
117 field.data_type().fmt(f)?;
118 writeln!(f)?;
119 Ok::<_, std::fmt::Error>(())
120 })?;
121 Ok(())
122}
123
124#[pymethods]
125impl PySchema {
126 #[new]
127 #[pyo3(signature = (fields, *, metadata=None))]
128 fn init(fields: Vec<PyField>, metadata: Option<MetadataInput>) -> PyResult<Self> {
129 let fields = fields
130 .into_iter()
131 .map(|field| field.into_inner())
132 .collect::<Vec<_>>();
133 let schema = PySchema::new(
134 Schema::new_with_metadata(fields, metadata.unwrap_or_default().into_string_hashmap()?)
135 .into(),
136 );
137 Ok(schema)
138 }
139
140 fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
141 to_schema_pycapsule(py, self.0.as_ref())
142 }
143
144 fn __eq__(&self, other: &PySchema) -> bool {
145 self.0 == other.0
146 }
147
148 fn __getitem__(&self, key: FieldIndexInput) -> PyArrowResult<Arro3Field> {
149 self.field(key)
150 }
151
152 fn __len__(&self) -> usize {
153 self.0.fields().len()
154 }
155
156 fn __repr__(&self) -> String {
157 self.to_string()
158 }
159
160 #[classmethod]
161 fn from_arrow(_cls: &Bound<PyType>, input: Self) -> Self {
162 input
163 }
164
165 #[classmethod]
166 #[pyo3(name = "from_arrow_pycapsule")]
167 fn from_arrow_pycapsule_py(_cls: &Bound<PyType>, capsule: &Bound<PyCapsule>) -> PyResult<Self> {
168 Self::from_arrow_pycapsule(capsule)
169 }
170
171 fn append(&self, field: PyField) -> Arro3Schema {
172 let mut fields = self.0.fields().to_vec();
173 fields.push(field.into_inner());
174 Schema::new_with_metadata(fields, self.0.metadata().clone()).into()
175 }
176
177 fn empty_table(&self) -> PyResult<Arro3Table> {
178 Ok(PyTable::try_new(vec![], self.into())?.into())
179 }
180
181 fn equals(&self, other: PySchema) -> bool {
182 self.0 == other.0
183 }
184
185 fn field(&self, i: FieldIndexInput) -> PyArrowResult<Arro3Field> {
186 let index = i.into_position(&self.0)?;
187 Ok(self.0.field(index).into())
188 }
189
190 fn get_all_field_indices(&self, name: String) -> Vec<usize> {
191 let mut indices = self
192 .0
193 .fields()
194 .iter()
195 .enumerate()
196 .filter(|(_idx, field)| field.name() == name.as_str())
197 .map(|(idx, _field)| idx)
198 .collect::<Vec<_>>();
199 indices.sort();
200 indices
201 }
202
203 fn get_field_index(&self, name: String) -> PyArrowResult<usize> {
204 let indices = self
205 .0
206 .fields()
207 .iter()
208 .enumerate()
209 .filter(|(_idx, field)| field.name() == name.as_str())
210 .map(|(idx, _field)| idx)
211 .collect::<Vec<_>>();
212 if indices.len() == 1 {
213 Ok(indices[0])
214 } else {
215 Err(PyValueError::new_err("Multiple fields with given name").into())
216 }
217 }
218
219 fn insert(&self, i: usize, field: PyField) -> Arro3Schema {
220 let mut fields = self.0.fields().to_vec();
221 fields.insert(i, field.into_inner());
222 Schema::new_with_metadata(fields, self.0.metadata().clone()).into()
223 }
224
225 #[getter]
228 fn metadata<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
229 let d = PyDict::new(py);
230 self.0.metadata().iter().try_for_each(|(key, val)| {
231 d.set_item(
232 PyBytes::new(py, key.as_bytes()),
233 PyBytes::new(py, val.as_bytes()),
234 )
235 })?;
236 Ok(d)
237 }
238
239 #[getter]
240 fn metadata_str(&self) -> HashMap<String, String> {
241 self.0.metadata().clone()
242 }
243
244 #[getter]
245 fn names(&self) -> Vec<String> {
246 self.0.fields().iter().map(|f| f.name().clone()).collect()
247 }
248
249 fn remove(&self, i: usize) -> Arro3Schema {
250 let mut fields = self.0.fields().to_vec();
251 fields.remove(i);
252 Schema::new_with_metadata(fields, self.0.metadata().clone()).into()
253 }
254
255 fn remove_metadata(&self) -> Arro3Schema {
256 self.0
257 .as_ref()
258 .clone()
259 .with_metadata(Default::default())
260 .into()
261 }
262
263 fn set(&self, i: usize, field: PyField) -> Arro3Schema {
264 let mut fields = self.0.fields().to_vec();
265 fields[i] = field.into_inner();
266 Schema::new_with_metadata(fields, self.0.metadata().clone()).into()
267 }
268
269 #[getter]
270 fn types(&self) -> Vec<Arro3DataType> {
271 self.0
272 .fields()
273 .iter()
274 .map(|f| PyDataType::new(f.data_type().clone()).into())
275 .collect()
276 }
277
278 fn with_metadata(&self, metadata: MetadataInput) -> PyResult<Arro3Schema> {
279 let schema = self
280 .0
281 .as_ref()
282 .clone()
283 .with_metadata(metadata.into_string_hashmap()?);
284 Ok(schema.into())
285 }
286}