1use std::collections::HashMap;
2use std::fmt::Display;
3use std::sync::Arc;
4
5use arrow_schema::{Schema, SchemaRef};
6use pyo3::exceptions::{PyTypeError, PyValueError};
7use pyo3::intern;
8use pyo3::prelude::*;
9use pyo3::types::{PyBytes, PyCapsule, PyDict, PyTuple, PyType};
10
11use crate::error::PyArrowResult;
12use crate::export::{Arro3DataType, Arro3Field, Arro3Schema, Arro3Table};
13use crate::ffi::from_python::utils::import_schema_pycapsule;
14use crate::ffi::to_python::nanoarrow::to_nanoarrow_schema;
15use crate::ffi::to_python::to_schema_pycapsule;
16use crate::input::{FieldIndexInput, MetadataInput};
17use crate::{PyDataType, PyField, PyTable};
18
/// Python-exposed wrapper around an Arrow [`SchemaRef`].
///
/// Registered with pyo3 as the frozen, subclassable class `Schema` in the
/// module `arro3.core._core`.
#[derive(Debug)]
#[pyclass(module = "arro3.core._core", name = "Schema", subclass, frozen)]
pub struct PySchema(SchemaRef);
25
26impl PySchema {
27 pub fn new(schema: SchemaRef) -> Self {
29 Self(schema)
30 }
31
32 pub fn from_arrow_pycapsule(capsule: &Bound<PyCapsule>) -> PyResult<Self> {
34 let schema_ptr = import_schema_pycapsule(capsule)?;
35 let schema =
36 Schema::try_from(schema_ptr).map_err(|err| PyTypeError::new_err(err.to_string()))?;
37 Ok(Self::new(Arc::new(schema)))
38 }
39
40 pub fn into_inner(self) -> SchemaRef {
42 self.0
43 }
44
45 pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
47 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
48 arro3_mod.getattr(intern!(py, "Schema"))?.call_method1(
49 intern!(py, "from_arrow_pycapsule"),
50 PyTuple::new(py, vec![self.__arrow_c_schema__(py)?])?,
51 )
52 }
53
54 pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
56 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
57 let capsule = to_schema_pycapsule(py, self.0.as_ref())?;
58 arro3_mod.getattr(intern!(py, "Schema"))?.call_method1(
59 intern!(py, "from_arrow_pycapsule"),
60 PyTuple::new(py, vec![capsule])?,
61 )
62 }
63
64 pub fn to_nanoarrow<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
66 to_nanoarrow_schema(py, &self.__arrow_c_schema__(py)?)
67 }
68
69 pub fn into_pyarrow(self, py: Python) -> PyResult<Bound<PyAny>> {
73 let pyarrow_mod = py.import(intern!(py, "pyarrow"))?;
74 pyarrow_mod
75 .getattr(intern!(py, "schema"))?
76 .call1(PyTuple::new(py, vec![self.into_pyobject(py)?])?)
77 }
78}
79
80impl From<PySchema> for SchemaRef {
81 fn from(value: PySchema) -> Self {
82 value.0
83 }
84}
85
86impl From<&PySchema> for SchemaRef {
87 fn from(value: &PySchema) -> Self {
88 value.0.as_ref().clone().into()
89 }
90}
91
92impl From<SchemaRef> for PySchema {
93 fn from(value: SchemaRef) -> Self {
94 Self(value)
95 }
96}
97
98impl AsRef<Schema> for PySchema {
99 fn as_ref(&self) -> &Schema {
100 &self.0
101 }
102}
103
104impl Display for PySchema {
105 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
106 writeln!(f, "arro3.core.Schema")?;
107 writeln!(f, "------------")?;
108 display_schema(&self.0, f)
109 }
110}
111
112pub(crate) fn display_schema(schema: &Schema, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 schema.fields().iter().try_for_each(|field| {
114 f.write_str(field.name().as_str())?;
115 write!(f, ": ")?;
116 field.data_type().fmt(f)?;
117 writeln!(f)?;
118 Ok::<_, std::fmt::Error>(())
119 })?;
120 Ok(())
121}
122
123#[pymethods]
124impl PySchema {
125 #[new]
126 #[pyo3(signature = (fields, *, metadata=None))]
127 fn init(fields: Vec<PyField>, metadata: Option<MetadataInput>) -> PyResult<Self> {
128 let fields = fields
129 .into_iter()
130 .map(|field| field.into_inner())
131 .collect::<Vec<_>>();
132 let schema = PySchema::new(
133 Schema::new_with_metadata(fields, metadata.unwrap_or_default().into_string_hashmap()?)
134 .into(),
135 );
136 Ok(schema)
137 }
138
139 fn __arrow_c_schema__<'py>(&'py self, py: Python<'py>) -> PyArrowResult<Bound<'py, PyCapsule>> {
140 to_schema_pycapsule(py, self.0.as_ref())
141 }
142
143 fn __eq__(&self, other: &PySchema) -> bool {
144 self.0 == other.0
145 }
146
147 fn __getitem__(&self, key: FieldIndexInput) -> PyArrowResult<Arro3Field> {
148 self.field(key)
149 }
150
151 fn __len__(&self) -> usize {
152 self.0.fields().len()
153 }
154
155 fn __repr__(&self) -> String {
156 self.to_string()
157 }
158
159 #[classmethod]
160 fn from_arrow(_cls: &Bound<PyType>, input: Self) -> Self {
161 input
162 }
163
164 #[classmethod]
165 #[pyo3(name = "from_arrow_pycapsule")]
166 fn from_arrow_pycapsule_py(_cls: &Bound<PyType>, capsule: &Bound<PyCapsule>) -> PyResult<Self> {
167 Self::from_arrow_pycapsule(capsule)
168 }
169
170 fn append(&self, field: PyField) -> Arro3Schema {
171 let mut fields = self.0.fields().to_vec();
172 fields.push(field.into_inner());
173 Schema::new_with_metadata(fields, self.0.metadata().clone()).into()
174 }
175
176 fn empty_table(&self) -> PyResult<Arro3Table> {
177 Ok(PyTable::try_new(vec![], self.into())?.into())
178 }
179
180 fn equals(&self, other: PySchema) -> bool {
181 self.0 == other.0
182 }
183
184 fn field(&self, i: FieldIndexInput) -> PyArrowResult<Arro3Field> {
185 let index = i.into_position(&self.0)?;
186 Ok(self.0.field(index).into())
187 }
188
189 fn get_all_field_indices(&self, name: String) -> Vec<usize> {
190 let mut indices = self
191 .0
192 .fields()
193 .iter()
194 .enumerate()
195 .filter(|(_idx, field)| field.name() == name.as_str())
196 .map(|(idx, _field)| idx)
197 .collect::<Vec<_>>();
198 indices.sort();
199 indices
200 }
201
202 fn get_field_index(&self, name: String) -> PyArrowResult<usize> {
203 let indices = self
204 .0
205 .fields()
206 .iter()
207 .enumerate()
208 .filter(|(_idx, field)| field.name() == name.as_str())
209 .map(|(idx, _field)| idx)
210 .collect::<Vec<_>>();
211 if indices.len() == 1 {
212 Ok(indices[0])
213 } else {
214 Err(PyValueError::new_err("Multiple fields with given name").into())
215 }
216 }
217
218 fn insert(&self, i: usize, field: PyField) -> Arro3Schema {
219 let mut fields = self.0.fields().to_vec();
220 fields.insert(i, field.into_inner());
221 Schema::new_with_metadata(fields, self.0.metadata().clone()).into()
222 }
223
224 #[getter]
227 fn metadata<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyDict>> {
228 let d = PyDict::new(py);
229 self.0.metadata().iter().try_for_each(|(key, val)| {
230 d.set_item(
231 PyBytes::new(py, key.as_bytes()),
232 PyBytes::new(py, val.as_bytes()),
233 )
234 })?;
235 Ok(d)
236 }
237
238 #[getter]
239 fn metadata_str(&self) -> HashMap<String, String> {
240 self.0.metadata().clone()
241 }
242
243 #[getter]
244 fn names(&self) -> Vec<String> {
245 self.0.fields().iter().map(|f| f.name().clone()).collect()
246 }
247
248 fn remove(&self, i: usize) -> Arro3Schema {
249 let mut fields = self.0.fields().to_vec();
250 fields.remove(i);
251 Schema::new_with_metadata(fields, self.0.metadata().clone()).into()
252 }
253
254 fn remove_metadata(&self) -> Arro3Schema {
255 self.0
256 .as_ref()
257 .clone()
258 .with_metadata(Default::default())
259 .into()
260 }
261
262 fn set(&self, i: usize, field: PyField) -> Arro3Schema {
263 let mut fields = self.0.fields().to_vec();
264 fields[i] = field.into_inner();
265 Schema::new_with_metadata(fields, self.0.metadata().clone()).into()
266 }
267
268 #[getter]
269 fn types(&self) -> Vec<Arro3DataType> {
270 self.0
271 .fields()
272 .iter()
273 .map(|f| PyDataType::new(f.data_type().clone()).into())
274 .collect()
275 }
276
277 fn with_metadata(&self, metadata: MetadataInput) -> PyResult<Arro3Schema> {
278 let schema = self
279 .0
280 .as_ref()
281 .clone()
282 .with_metadata(metadata.into_string_hashmap()?);
283 Ok(schema.into())
284 }
285}