reifydb_function/text/
substring.rs1use reifydb_core::value::column::data::ColumnData;
5use reifydb_type::value::{container::utf8::Utf8Container, r#type::Type};
6
7use crate::{
8 ScalarFunction, ScalarFunctionContext,
9 error::{ScalarFunctionError, ScalarFunctionResult},
10 propagate_options,
11};
12
/// Scalar function object for `substring(text, start, length)`.
///
/// The type carries no state; all behaviour lives in its `ScalarFunction`
/// implementation.
#[derive(Default)]
pub struct TextSubstring;

impl TextSubstring {
    /// Creates the (stateless) function object.
    ///
    /// Equivalent to `TextSubstring::default()`.
    pub fn new() -> Self {
        Self
    }
}
20
21impl ScalarFunction for TextSubstring {
22 fn scalar(&self, ctx: ScalarFunctionContext) -> ScalarFunctionResult<ColumnData> {
23 if let Some(result) = propagate_options(self, &ctx) {
24 return result;
25 }
26
27 let columns = ctx.columns;
28 let row_count = ctx.row_count;
29
30 if columns.len() != 3 {
32 return Err(ScalarFunctionError::ArityMismatch {
33 function: ctx.fragment.clone(),
34 expected: 3,
35 actual: columns.len(),
36 });
37 }
38
39 let text_column = columns.get(0).unwrap();
40 let start_column = columns.get(1).unwrap();
41 let length_column = columns.get(2).unwrap();
42
43 match (text_column.data(), start_column.data(), length_column.data()) {
44 (
45 ColumnData::Utf8 {
46 container: text_container,
47 max_bytes,
48 },
49 ColumnData::Int4(start_container),
50 ColumnData::Int4(length_container),
51 ) => {
52 let mut result_data = Vec::with_capacity(text_container.data().len());
53
54 for i in 0..row_count {
55 if text_container.is_defined(i)
56 && start_container.is_defined(i) && length_container.is_defined(i)
57 {
58 let original_str = &text_container[i];
59 let start_pos = start_container.get(i).copied().unwrap_or(0);
60 let length = length_container.get(i).copied().unwrap_or(0);
61
62 let chars: Vec<char> = original_str.chars().collect();
64 let chars_len = chars.len();
65
66 let start_idx = if start_pos < 0 {
68 chars_len.saturating_sub((-start_pos) as usize)
69 } else {
70 start_pos as usize
71 };
72 let length_usize = if length < 0 {
73 0
74 } else {
75 length as usize
76 };
77
78 let substring = if start_idx >= chars_len {
79 String::new()
81 } else {
82 let end_idx = (start_idx + length_usize).min(chars_len);
83 chars[start_idx..end_idx].iter().collect()
84 };
85
86 result_data.push(substring);
87 } else {
88 result_data.push(String::new());
89 }
90 }
91
92 Ok(ColumnData::Utf8 {
93 container: Utf8Container::new(result_data),
94 max_bytes: *max_bytes,
95 })
96 }
97 (
99 ColumnData::Utf8 {
100 container: text_container,
101 max_bytes,
102 },
103 start_data,
104 length_data,
105 ) => {
106 let mut result_data = Vec::with_capacity(text_container.data().len());
107
108 for i in 0..row_count {
109 if text_container.is_defined(i) {
110 let original_str = &text_container[i];
111
112 let start_pos = match start_data {
114 ColumnData::Int1(container) => {
115 container.get(i).map(|&v| v as i32).unwrap_or(0)
116 }
117 ColumnData::Int2(container) => {
118 container.get(i).map(|&v| v as i32).unwrap_or(0)
119 }
120 ColumnData::Int4(container) => {
121 container.get(i).copied().unwrap_or(0)
122 }
123 ColumnData::Int8(container) => {
124 container.get(i).map(|&v| v as i32).unwrap_or(0)
125 }
126 _ => 0,
127 };
128
129 let length = match length_data {
131 ColumnData::Int1(container) => {
132 container.get(i).map(|&v| v as i32).unwrap_or(0)
133 }
134 ColumnData::Int2(container) => {
135 container.get(i).map(|&v| v as i32).unwrap_or(0)
136 }
137 ColumnData::Int4(container) => {
138 container.get(i).copied().unwrap_or(0)
139 }
140 ColumnData::Int8(container) => {
141 container.get(i).map(|&v| v as i32).unwrap_or(0)
142 }
143 _ => 0,
144 };
145
146 let chars: Vec<char> = original_str.chars().collect();
148 let chars_len = chars.len();
149
150 let start_idx = if start_pos < 0 {
152 chars_len.saturating_sub((-start_pos) as usize)
153 } else {
154 start_pos as usize
155 };
156 let length_usize = if length < 0 {
157 0
158 } else {
159 length as usize
160 };
161
162 let substring = if start_idx >= chars_len {
163 String::new()
165 } else {
166 let end_idx = (start_idx + length_usize).min(chars_len);
167 chars[start_idx..end_idx].iter().collect()
168 };
169
170 result_data.push(substring);
171 } else {
172 result_data.push(String::new());
173 }
174 }
175
176 Ok(ColumnData::Utf8 {
177 container: Utf8Container::new(result_data),
178 max_bytes: *max_bytes,
179 })
180 }
181 (other, _, _) => Err(ScalarFunctionError::InvalidArgumentType {
182 function: ctx.fragment.clone(),
183 argument_index: 0,
184 expected: vec![Type::Utf8],
185 actual: other.get_type(),
186 }),
187 }
188 }
189
190 fn return_type(&self, _input_types: &[Type]) -> Type {
191 Type::Utf8
192 }
193}