arrow_schema/extension/canonical/
opaque.rs1use serde_core::ser::SerializeStruct;
23use serde_core::{
24 Deserialize, Deserializer, Serialize, Serializer,
25 de::{MapAccess, Visitor},
26};
27
28use crate::{ArrowError, DataType, extension::ExtensionType};
29
/// The canonical `arrow.opaque` extension type.
///
/// Wraps an [`OpaqueMetadata`] carrying a `type_name` and a `vendor_name`.
/// Per the Arrow canonical extension type specification, this describes a
/// type from an external system that the reader cannot interpret. Any
/// storage [`DataType`] is accepted.
#[derive(Debug, Clone, PartialEq)]
pub struct Opaque(OpaqueMetadata);
43
44impl Opaque {
45 pub fn new(type_name: impl Into<String>, vendor_name: impl Into<String>) -> Self {
47 Self(OpaqueMetadata::new(type_name, vendor_name))
48 }
49
50 pub fn type_name(&self) -> &str {
52 self.0.type_name()
53 }
54
55 pub fn vendor_name(&self) -> &str {
57 self.0.vendor_name()
58 }
59}
60
61impl From<OpaqueMetadata> for Opaque {
62 fn from(value: OpaqueMetadata) -> Self {
63 Self(value)
64 }
65}
66
/// Metadata for the [`Opaque`] extension type.
///
/// Serialized as a JSON object with the string fields `type_name` and
/// `vendor_name`; both fields are required when deserializing.
#[derive(Debug, Clone, PartialEq)]
pub struct OpaqueMetadata {
    /// Value of the `type_name` key: the name of the opaque type.
    type_name: String,

    /// Value of the `vendor_name` key: the name of the vendor/system the
    /// type comes from.
    vendor_name: String,
}
76
77impl Serialize for OpaqueMetadata {
78 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
79 where
80 S: Serializer,
81 {
82 let mut state = serializer.serialize_struct("OpaqueMetadata", 2)?;
83 state.serialize_field("type_name", &self.type_name)?;
84 state.serialize_field("vendor_name", &self.vendor_name)?;
85 state.end()
86 }
87}
88
/// Identifies which [`OpaqueMetadata`] field a deserialized map key refers to.
#[derive(Debug)]
enum MetadataField {
    /// The `type_name` key.
    TypeName,
    /// The `vendor_name` key.
    VendorName,
}
94
95struct MetadataFieldVisitor;
96
97impl<'de> Visitor<'de> for MetadataFieldVisitor {
98 type Value = MetadataField;
99
100 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
101 formatter.write_str("`type_name` or `vendor_name`")
102 }
103
104 fn visit_str<E>(self, value: &str) -> Result<MetadataField, E>
105 where
106 E: serde_core::de::Error,
107 {
108 match value {
109 "type_name" => Ok(MetadataField::TypeName),
110 "vendor_name" => Ok(MetadataField::VendorName),
111 _ => Err(serde_core::de::Error::unknown_field(
112 value,
113 &["type_name", "vendor_name"],
114 )),
115 }
116 }
117}
118
119impl<'de> Deserialize<'de> for MetadataField {
120 fn deserialize<D>(deserializer: D) -> Result<MetadataField, D::Error>
121 where
122 D: Deserializer<'de>,
123 {
124 deserializer.deserialize_identifier(MetadataFieldVisitor)
125 }
126}
127
128struct OpaqueMetadataVisitor;
129
130impl<'de> Visitor<'de> for OpaqueMetadataVisitor {
131 type Value = OpaqueMetadata;
132
133 fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
134 formatter.write_str("struct OpaqueMetadata")
135 }
136
137 fn visit_seq<V>(self, mut seq: V) -> Result<OpaqueMetadata, V::Error>
138 where
139 V: serde_core::de::SeqAccess<'de>,
140 {
141 let type_name = seq
142 .next_element()?
143 .ok_or_else(|| serde_core::de::Error::invalid_length(0, &self))?;
144 let vendor_name = seq
145 .next_element()?
146 .ok_or_else(|| serde_core::de::Error::invalid_length(1, &self))?;
147 Ok(OpaqueMetadata {
148 type_name,
149 vendor_name,
150 })
151 }
152
153 fn visit_map<V>(self, mut map: V) -> Result<OpaqueMetadata, V::Error>
154 where
155 V: MapAccess<'de>,
156 {
157 let mut type_name = None;
158 let mut vendor_name = None;
159
160 while let Some(key) = map.next_key()? {
161 match key {
162 MetadataField::TypeName => {
163 if type_name.is_some() {
164 return Err(serde_core::de::Error::duplicate_field("type_name"));
165 }
166 type_name = Some(map.next_value()?);
167 }
168 MetadataField::VendorName => {
169 if vendor_name.is_some() {
170 return Err(serde_core::de::Error::duplicate_field("vendor_name"));
171 }
172 vendor_name = Some(map.next_value()?);
173 }
174 }
175 }
176
177 let type_name =
178 type_name.ok_or_else(|| serde_core::de::Error::missing_field("type_name"))?;
179 let vendor_name =
180 vendor_name.ok_or_else(|| serde_core::de::Error::missing_field("vendor_name"))?;
181
182 Ok(OpaqueMetadata {
183 type_name,
184 vendor_name,
185 })
186 }
187}
188
189impl<'de> Deserialize<'de> for OpaqueMetadata {
190 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
191 where
192 D: Deserializer<'de>,
193 {
194 deserializer.deserialize_struct(
195 "OpaqueMetadata",
196 &["type_name", "vendor_name"],
197 OpaqueMetadataVisitor,
198 )
199 }
200}
201
202impl OpaqueMetadata {
203 pub fn new(type_name: impl Into<String>, vendor_name: impl Into<String>) -> Self {
205 OpaqueMetadata {
206 type_name: type_name.into(),
207 vendor_name: vendor_name.into(),
208 }
209 }
210
211 pub fn type_name(&self) -> &str {
213 &self.type_name
214 }
215
216 pub fn vendor_name(&self) -> &str {
218 &self.vendor_name
219 }
220}
221
222impl ExtensionType for Opaque {
223 const NAME: &'static str = "arrow.opaque";
224
225 type Metadata = OpaqueMetadata;
226
227 fn metadata(&self) -> &Self::Metadata {
228 &self.0
229 }
230
231 fn serialize_metadata(&self) -> Option<String> {
232 Some(serde_json::to_string(self.metadata()).expect("metadata serialization"))
233 }
234
235 fn deserialize_metadata(metadata: Option<&str>) -> Result<Self::Metadata, ArrowError> {
236 metadata.map_or_else(
237 || {
238 Err(ArrowError::InvalidArgumentError(
239 "Opaque extension types requires metadata".to_owned(),
240 ))
241 },
242 |value| {
243 serde_json::from_str(value).map_err(|e| {
244 ArrowError::InvalidArgumentError(format!(
245 "Opaque metadata deserialization failed: {e}"
246 ))
247 })
248 },
249 )
250 }
251
252 fn supports_data_type(&self, _data_type: &DataType) -> Result<(), ArrowError> {
253 Ok(())
255 }
256
257 fn try_new(_data_type: &DataType, metadata: Self::Metadata) -> Result<Self, ArrowError> {
258 Ok(Self::from(metadata))
259 }
260}
261
#[cfg(test)]
mod tests {
    #[cfg(feature = "canonical_extension_types")]
    use crate::extension::CanonicalExtensionType;
    use crate::{
        Field,
        extension::{EXTENSION_TYPE_METADATA_KEY, EXTENSION_TYPE_NAME_KEY},
    };

    use super::*;

    #[test]
    fn valid() -> Result<(), ArrowError> {
        let opaque = Opaque::new("name", "vendor");
        let mut field = Field::new("", DataType::Null, false);
        field.try_with_extension_type(opaque.clone())?;
        assert_eq!(field.try_extension_type::<Opaque>()?, opaque);
        #[cfg(feature = "canonical_extension_types")]
        assert_eq!(
            field.try_canonical_extension_type()?,
            CanonicalExtensionType::Opaque(opaque)
        );
        Ok(())
    }

    #[test]
    #[should_panic(expected = "Field extension type name missing")]
    fn missing_name() {
        let field = Field::new("", DataType::Null, false).with_metadata(
            [(
                EXTENSION_TYPE_METADATA_KEY.to_owned(),
                r#"{ "type_name": "type", "vendor_name": "vendor" }"#.to_owned(),
            )]
            .into_iter()
            .collect(),
        );
        field.extension_type::<Opaque>();
    }

    #[test]
    #[should_panic(expected = "Opaque extension types requires metadata")]
    fn missing_metadata() {
        let field = Field::new("", DataType::Null, false).with_metadata(
            [(EXTENSION_TYPE_NAME_KEY.to_owned(), Opaque::NAME.to_owned())]
                .into_iter()
                .collect(),
        );
        field.extension_type::<Opaque>();
    }

    #[test]
    #[should_panic(
        expected = "Opaque metadata deserialization failed: missing field `vendor_name`"
    )]
    fn invalid_metadata() {
        let field = Field::new("", DataType::Null, false).with_metadata(
            [
                (EXTENSION_TYPE_NAME_KEY.to_owned(), Opaque::NAME.to_owned()),
                (
                    EXTENSION_TYPE_METADATA_KEY.to_owned(),
                    r#"{ "type_name": "no-vendor" }"#.to_owned(),
                ),
            ]
            .into_iter()
            .collect(),
        );
        field.extension_type::<Opaque>();
    }

    /// The hand-written `Serialize` impl emits the exact wire format, and the
    /// hand-written `Deserialize` impl reads it back.
    #[test]
    fn metadata_json_round_trip() {
        let metadata = OpaqueMetadata::new("type", "vendor");
        let json = serde_json::to_string(&metadata).unwrap();
        assert_eq!(json, r#"{"type_name":"type","vendor_name":"vendor"}"#);
        let parsed: OpaqueMetadata = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, metadata);
    }

    /// Exercises the positional `visit_seq` path of the visitor.
    #[test]
    fn metadata_from_json_array() {
        let parsed: OpaqueMetadata = serde_json::from_str(r#"["type", "vendor"]"#).unwrap();
        assert_eq!(parsed, OpaqueMetadata::new("type", "vendor"));
    }

    #[test]
    fn metadata_rejects_duplicate_field() {
        let err = serde_json::from_str::<OpaqueMetadata>(
            r#"{ "type_name": "a", "type_name": "b", "vendor_name": "vendor" }"#,
        )
        .unwrap_err();
        assert!(
            err.to_string().contains("duplicate field `type_name`"),
            "{err}"
        );
    }

    #[test]
    fn metadata_rejects_unknown_field() {
        let err = serde_json::from_str::<OpaqueMetadata>(
            r#"{ "type_name": "type", "vendor_name": "vendor", "extra": 1 }"#,
        )
        .unwrap_err();
        assert!(err.to_string().contains("unknown field `extra`"), "{err}");
    }
}