1use std::fmt::Display;
2use std::str::FromStr;
3use std::sync::Arc;
4
5use arrow_array::cast::AsArray;
6use arrow_array::timezone::Tz;
7use arrow_array::types::*;
8use arrow_array::{Array, ArrayRef, Datum, UnionArray};
9use arrow_cast::cast;
10use arrow_schema::{ArrowError, DataType, Field, FieldRef, TimeUnit};
11use indexmap::IndexMap;
12use pyo3::prelude::*;
13use pyo3::types::{PyCapsule, PyList, PyTuple, PyType};
14use pyo3::{intern, IntoPyObjectExt};
15
16use crate::error::PyArrowResult;
17use crate::export::{Arro3DataType, Arro3Field, Arro3Scalar};
18use crate::ffi::to_array_pycapsules;
19use crate::{PyArray, PyField};
20
// Python-facing wrapper around a single Arrow value, exposed to Python as
// `arro3.core._core.Scalar` (subclassable, frozen).
//
// The value is stored as an Arrow array; `try_new` only accepts arrays whose
// length is exactly 1, while `new_unchecked` skips that check.
#[derive(Debug)]
#[pyclass(module = "arro3.core._core", name = "Scalar", subclass, frozen)]
pub struct PyScalar {
    // Arrow array holding the scalar's value (read at index 0 by the methods below).
    array: ArrayRef,
    // Field describing `array` (name, data type, nullability).
    field: FieldRef,
}
28
29impl PyScalar {
30 pub unsafe fn new_unchecked(array: ArrayRef, field: FieldRef) -> Self {
37 Self { array, field }
38 }
39
40 pub fn try_from_array_ref(array: ArrayRef) -> PyArrowResult<Self> {
42 let field = Field::new("", array.data_type().clone(), true);
43 Self::try_new(array, Arc::new(field))
44 }
45
46 pub fn try_new(array: ArrayRef, field: FieldRef) -> PyArrowResult<Self> {
51 let (array, field) = PyArray::try_new(array, field)?.into_inner();
53 if array.len() != 1 {
54 return Err(ArrowError::SchemaError(
55 "Expected array of length 1 for scalar".to_string(),
56 )
57 .into());
58 }
59
60 Ok(Self { array, field })
61 }
62
63 pub fn try_from_arrow_pycapsule(
65 schema_capsule: &Bound<PyCapsule>,
66 array_capsule: &Bound<PyCapsule>,
67 ) -> PyArrowResult<Self> {
68 let (array, field) =
69 PyArray::from_arrow_pycapsule(schema_capsule, array_capsule)?.into_inner();
70 Self::try_new(array, field)
71 }
72
73 pub fn array(&self) -> &ArrayRef {
75 &self.array
76 }
77
78 pub fn field(&self) -> &FieldRef {
80 &self.field
81 }
82
83 pub fn into_inner(self) -> (ArrayRef, FieldRef) {
85 (self.array, self.field)
86 }
87
88 pub fn to_arro3<'py>(&'py self, py: Python<'py>) -> PyResult<Bound<'py, PyAny>> {
92 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
93 arro3_mod.getattr(intern!(py, "Scalar"))?.call_method1(
94 intern!(py, "from_arrow_pycapsule"),
95 self.__arrow_c_array__(py, None)?,
96 )
97 }
98
99 pub fn into_arro3(self, py: Python) -> PyResult<Bound<PyAny>> {
103 let arro3_mod = py.import(intern!(py, "arro3.core"))?;
104 let capsules = to_array_pycapsules(py, self.field.clone(), &self.array, None)?;
105 arro3_mod
106 .getattr(intern!(py, "Scalar"))?
107 .call_method1(intern!(py, "from_arrow_pycapsule"), capsules)
108 }
109}
110
111impl Display for PyScalar {
112 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
113 write!(f, "arro3.core.Scalar<")?;
114 self.array.data_type().fmt(f)?;
115 writeln!(f, ">")?;
116 Ok(())
117 }
118}
119
120impl Datum for PyScalar {
121 fn get(&self) -> (&dyn Array, bool) {
122 (self.array.as_ref(), true)
123 }
124}
125
126#[pymethods]
127impl PyScalar {
128 #[new]
129 #[pyo3(signature = (obj, /, r#type = None, *))]
130 fn init(py: Python, obj: &Bound<PyAny>, r#type: Option<PyField>) -> PyArrowResult<Self> {
131 if let Ok(data) = obj.extract::<PyScalar>() {
132 return Ok(data);
133 }
134
135 let obj = PyList::new(py, vec![obj])?;
136 let array = PyArray::init(&obj, r#type)?;
137 let (array, field) = array.into_inner();
138 Self::try_new(array, field)
139 }
140
141 #[pyo3(signature = (requested_schema=None))]
142 fn __arrow_c_array__<'py>(
143 &'py self,
144 py: Python<'py>,
145 requested_schema: Option<Bound<'py, PyCapsule>>,
146 ) -> PyArrowResult<Bound<'py, PyTuple>> {
147 to_array_pycapsules(py, self.field.clone(), &self.array, requested_schema)
148 }
149
150 fn __eq__(&self, py: Python, other: Bound<'_, PyAny>) -> PyResult<PyObject> {
151 if let Ok(other) = other.extract::<PyScalar>() {
152 let eq = self.array == other.array && self.field == other.field;
153 eq.into_py_any(py)
154 } else {
156 let self_py = self.as_py(py)?;
159 self_py.call_method1(py, intern!(py, "__eq__"), PyTuple::new(py, vec![other])?)
160 }
161 }
162
163 fn __repr__(&self) -> String {
164 self.to_string()
165 }
166
    // Classmethod that returns `input` unchanged; the conversion from the Python
    // argument into a `PyScalar` has already happened during argument extraction.
    #[classmethod]
    fn from_arrow(_cls: &Bound<PyType>, input: PyScalar) -> Self {
        input
    }
171
172 #[classmethod]
173 #[pyo3(name = "from_arrow_pycapsule")]
174 fn from_arrow_pycapsule_py(
175 _cls: &Bound<PyType>,
176 schema_capsule: &Bound<PyCapsule>,
177 array_capsule: &Bound<PyCapsule>,
178 ) -> PyArrowResult<Self> {
179 Self::try_from_arrow_pycapsule(schema_capsule, array_capsule)
180 }
181
182 pub(crate) fn as_py(&self, py: Python) -> PyArrowResult<PyObject> {
183 if self.array.is_null(0) {
184 return Ok(py.None());
185 }
186
187 let arr = self.array.as_ref();
188 let result = match self.array.data_type() {
189 DataType::Null => py.None(),
190 DataType::Boolean => arr.as_boolean().value(0).into_py_any(py)?,
191 DataType::Int8 => arr.as_primitive::<Int8Type>().value(0).into_py_any(py)?,
192 DataType::Int16 => arr.as_primitive::<Int16Type>().value(0).into_py_any(py)?,
193 DataType::Int32 => arr.as_primitive::<Int32Type>().value(0).into_py_any(py)?,
194 DataType::Int64 => arr.as_primitive::<Int64Type>().value(0).into_py_any(py)?,
195 DataType::UInt8 => arr.as_primitive::<UInt8Type>().value(0).into_py_any(py)?,
196 DataType::UInt16 => arr.as_primitive::<UInt16Type>().value(0).into_py_any(py)?,
197 DataType::UInt32 => arr.as_primitive::<UInt32Type>().value(0).into_py_any(py)?,
198 DataType::UInt64 => arr.as_primitive::<UInt64Type>().value(0).into_py_any(py)?,
199 DataType::Float16 => {
200 f32::from(arr.as_primitive::<Float16Type>().value(0)).into_py_any(py)?
201 }
202 DataType::Float32 => arr.as_primitive::<Float32Type>().value(0).into_py_any(py)?,
203 DataType::Float64 => arr.as_primitive::<Float64Type>().value(0).into_py_any(py)?,
204 DataType::Timestamp(time_unit, tz) => {
205 if let Some(tz) = tz {
206 let tz = Tz::from_str(tz)?;
207 match time_unit {
208 TimeUnit::Second => arr
209 .as_primitive::<TimestampSecondType>()
210 .value_as_datetime_with_tz(0, tz)
211 .map(|dt| dt.fixed_offset())
212 .into_py_any(py)?,
213 TimeUnit::Millisecond => arr
214 .as_primitive::<TimestampMillisecondType>()
215 .value_as_datetime_with_tz(0, tz)
216 .map(|dt| dt.fixed_offset())
217 .into_py_any(py)?,
218 TimeUnit::Microsecond => arr
219 .as_primitive::<TimestampMicrosecondType>()
220 .value_as_datetime_with_tz(0, tz)
221 .map(|dt| dt.fixed_offset())
222 .into_py_any(py)?,
223 TimeUnit::Nanosecond => arr
224 .as_primitive::<TimestampNanosecondType>()
225 .value_as_datetime_with_tz(0, tz)
226 .map(|dt| dt.fixed_offset())
227 .into_py_any(py)?,
228 }
229 } else {
230 match time_unit {
231 TimeUnit::Second => arr
232 .as_primitive::<TimestampSecondType>()
233 .value_as_datetime(0)
234 .into_py_any(py)?,
235 TimeUnit::Millisecond => arr
236 .as_primitive::<TimestampMillisecondType>()
237 .value_as_datetime(0)
238 .into_py_any(py)?,
239 TimeUnit::Microsecond => arr
240 .as_primitive::<TimestampMicrosecondType>()
241 .value_as_datetime(0)
242 .into_py_any(py)?,
243 TimeUnit::Nanosecond => arr
244 .as_primitive::<TimestampNanosecondType>()
245 .value_as_datetime(0)
246 .into_py_any(py)?,
247 }
248 }
249 }
250 DataType::Date32 => arr
251 .as_primitive::<Date32Type>()
252 .value_as_date(0)
253 .into_py_any(py)?,
254 DataType::Date64 => arr
255 .as_primitive::<Date64Type>()
256 .value_as_date(0)
257 .into_py_any(py)?,
258 DataType::Time32(time_unit) => match time_unit {
259 TimeUnit::Second => arr
260 .as_primitive::<Time32SecondType>()
261 .value_as_time(0)
262 .into_py_any(py)?,
263 TimeUnit::Millisecond => arr
264 .as_primitive::<Time32MillisecondType>()
265 .value_as_time(0)
266 .into_py_any(py)?,
267 _ => unreachable!(),
268 },
269 DataType::Time64(time_unit) => match time_unit {
270 TimeUnit::Microsecond => arr
271 .as_primitive::<Time64MicrosecondType>()
272 .value_as_time(0)
273 .into_py_any(py)?,
274 TimeUnit::Nanosecond => arr
275 .as_primitive::<Time64NanosecondType>()
276 .value_as_time(0)
277 .into_py_any(py)?,
278
279 _ => unreachable!(),
280 },
281 DataType::Duration(time_unit) => match time_unit {
282 TimeUnit::Second => arr
283 .as_primitive::<DurationSecondType>()
284 .value_as_duration(0)
285 .into_py_any(py)?,
286 TimeUnit::Millisecond => arr
287 .as_primitive::<DurationMillisecondType>()
288 .value_as_duration(0)
289 .into_py_any(py)?,
290 TimeUnit::Microsecond => arr
291 .as_primitive::<DurationMicrosecondType>()
292 .value_as_duration(0)
293 .into_py_any(py)?,
294 TimeUnit::Nanosecond => arr
295 .as_primitive::<DurationNanosecondType>()
296 .value_as_duration(0)
297 .into_py_any(py)?,
298 },
299 DataType::Interval(_) => {
300 todo!("interval is not yet fully documented [ARROW-3097]")
302 }
303 DataType::Binary => arr.as_binary::<i32>().value(0).into_py_any(py)?,
304 DataType::FixedSizeBinary(_) => arr.as_fixed_size_binary().value(0).into_py_any(py)?,
305 DataType::LargeBinary => arr.as_binary::<i64>().value(0).into_py_any(py)?,
306 DataType::BinaryView => arr.as_binary_view().value(0).into_py_any(py)?,
307 DataType::Utf8 => arr.as_string::<i32>().value(0).into_py_any(py)?,
308 DataType::LargeUtf8 => arr.as_string::<i64>().value(0).into_py_any(py)?,
309 DataType::Utf8View => arr.as_string_view().value(0).into_py_any(py)?,
310 DataType::List(inner_field) => {
311 let inner_array = arr.as_list::<i32>().value(0);
312 list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
313 }
314 DataType::LargeList(inner_field) => {
315 let inner_array = arr.as_list::<i64>().value(0);
316 list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
317 }
318 DataType::FixedSizeList(inner_field, _list_size) => {
319 let inner_array = arr.as_fixed_size_list().value(0);
320 list_values_to_py(py, inner_array, inner_field)?.into_py_any(py)?
321 }
322 DataType::ListView(_inner_field) => {
323 todo!("as_list_view does not yet exist");
324 }
327 DataType::LargeListView(_inner_field) => {
328 todo!("as_list_view does not yet exist");
329 }
332 DataType::Struct(inner_fields) => {
333 let struct_array = arr.as_struct();
334 let mut dict_py_objects: IndexMap<&str, PyObject> =
335 IndexMap::with_capacity(inner_fields.len());
336 for (inner_field, column) in inner_fields.iter().zip(struct_array.columns()) {
337 let scalar =
338 unsafe { PyScalar::new_unchecked(column.clone(), inner_field.clone()) };
339 dict_py_objects.insert(inner_field.name(), scalar.as_py(py)?);
340 }
341 dict_py_objects.into_py_any(py)?
342 }
343 DataType::Union(_, _) => {
344 let array = arr.as_any().downcast_ref::<UnionArray>().unwrap();
345 let scalar = PyScalar::try_from_array_ref(array.value(0))?;
346 scalar.as_py(py)?
347 }
348 DataType::Dictionary(_, _) => {
349 let array = arr.as_any_dictionary();
350 let keys = array.keys();
351 let key = match keys.data_type() {
352 DataType::Int8 => keys.as_primitive::<Int8Type>().value(0) as usize,
353 DataType::Int16 => keys.as_primitive::<Int16Type>().value(0) as usize,
354 DataType::Int32 => keys.as_primitive::<Int32Type>().value(0) as usize,
355 DataType::Int64 => keys.as_primitive::<Int64Type>().value(0) as usize,
356 DataType::UInt8 => keys.as_primitive::<UInt8Type>().value(0) as usize,
357 DataType::UInt16 => keys.as_primitive::<UInt16Type>().value(0) as usize,
358 DataType::UInt32 => keys.as_primitive::<UInt32Type>().value(0) as usize,
359 DataType::UInt64 => keys.as_primitive::<UInt64Type>().value(0) as usize,
360 _ => unreachable!(),
363 };
364 let value = array.values().slice(key, 1);
365 PyScalar::try_from_array_ref(value)?.as_py(py)?
366 }
367
368 DataType::Decimal128(_, _) => {
379 todo!()
381 }
382 DataType::Decimal256(_, _) => {
383 todo!()
385 }
386 DataType::Map(_, _) => {
387 let array = arr.as_map();
388 let struct_arr = array.value(0);
389 let key_arr = struct_arr.column_by_name("key").unwrap();
390 let value_arr = struct_arr.column_by_name("value").unwrap();
391
392 let mut entries = Vec::with_capacity(struct_arr.len());
393 for i in 0..struct_arr.len() {
394 let py_key = PyScalar::try_from_array_ref(key_arr.slice(i, 1))?.as_py(py)?;
395 let py_value =
396 PyScalar::try_from_array_ref(value_arr.slice(i, 1))?.as_py(py)?;
397 entries.push(PyTuple::new(py, vec![py_key, py_value])?);
398 }
399
400 entries.into_py_any(py)?
401 }
402 DataType::RunEndEncoded(_, _) => {
403 todo!()
404 }
405 };
406 Ok(result)
407 }
408
409 fn cast(&self, target_type: PyField) -> PyArrowResult<Arro3Scalar> {
410 let new_field = target_type.into_inner();
411 let new_array = cast(&self.array, new_field.data_type())?;
412 Ok(PyScalar::try_new(new_array, new_field).unwrap().into())
413 }
414
415 #[getter]
416 #[pyo3(name = "field")]
417 fn py_field(&self) -> Arro3Field {
418 self.field.clone().into()
419 }
420
421 #[getter]
422 fn is_valid(&self) -> bool {
423 self.array.is_valid(0)
424 }
425
426 #[getter]
427 fn r#type(&self) -> Arro3DataType {
428 self.field.data_type().clone().into()
429 }
430}
431
432fn list_values_to_py(
433 py: Python,
434 inner_array: ArrayRef,
435 inner_field: &Arc<Field>,
436) -> PyArrowResult<Vec<PyObject>> {
437 let mut list_py_objects = Vec::with_capacity(inner_array.len());
438 for i in 0..inner_array.len() {
439 let scalar =
440 unsafe { PyScalar::new_unchecked(inner_array.slice(i, 1), inner_field.clone()) };
441 list_py_objects.push(scalar.as_py(py)?);
442 }
443 Ok(list_py_objects)
444}