Skip to main content

reifydb_routine/function/text/
substring.rs

1// SPDX-License-Identifier: Apache-2.0
2// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{ColumnWithName, buffer::ColumnBuffer, columns::Columns};
5use reifydb_type::{
6	util::bitvec::BitVec,
7	value::{container::utf8::Utf8Container, r#type::Type},
8};
9
10use crate::routine::{Function, FunctionKind, Routine, RoutineInfo, context::FunctionContext, error::RoutineError};
11
/// Routine implementing the `text::substring` function: extracts a run of
/// characters from a text value given a start position and a length.
pub struct TextSubstring {
	/// Routine metadata, created in `new` from the name `"text::substring"`
	/// and handed out by `Routine::info`.
	info: RoutineInfo,
}
15
16impl Default for TextSubstring {
17	fn default() -> Self {
18		Self::new()
19	}
20}
21
22impl TextSubstring {
23	pub fn new() -> Self {
24		Self {
25			info: RoutineInfo::new("text::substring"),
26		}
27	}
28}
29
30impl<'a> Routine<FunctionContext<'a>> for TextSubstring {
31	fn info(&self) -> &RoutineInfo {
32		&self.info
33	}
34
35	fn return_type(&self, _input_types: &[Type]) -> Type {
36		Type::Utf8
37	}
38
39	fn execute(&self, ctx: &mut FunctionContext<'a>, args: &Columns) -> Result<Columns, RoutineError> {
40		// Validate exactly 3 arguments
41		if args.len() != 3 {
42			return Err(RoutineError::FunctionArityMismatch {
43				function: ctx.fragment.clone(),
44				expected: 3,
45				actual: args.len(),
46			});
47		}
48
49		let text_col = &args[0];
50		let start_col = &args[1];
51		let length_col = &args[2];
52
53		let (text_data, text_bv) = text_col.unwrap_option();
54		let (start_data, start_bv) = start_col.unwrap_option();
55		let (length_data, length_bv) = length_col.unwrap_option();
56		let row_count = text_data.len();
57
58		match (text_data, start_data, length_data) {
59			(
60				ColumnBuffer::Utf8 {
61					container: text_container,
62					max_bytes,
63				},
64				ColumnBuffer::Int4(start_container),
65				ColumnBuffer::Int4(length_container),
66			) => {
67				let mut result_data = Vec::with_capacity(text_container.len());
68
69				for i in 0..row_count {
70					if text_container.is_defined(i)
71						&& start_container.is_defined(i) && length_container.is_defined(i)
72					{
73						let original_str = text_container.get(i).unwrap();
74						let start_pos = start_container.get(i).copied().unwrap_or(0);
75						let length = length_container.get(i).copied().unwrap_or(0);
76
77						// Get the substring with proper Unicode handling
78						let chars: Vec<char> = original_str.chars().collect();
79						let chars_len = chars.len();
80
81						// Convert negative start to positive index from end
82						let start_idx = if start_pos < 0 {
83							chars_len.saturating_sub((-start_pos) as usize)
84						} else {
85							start_pos as usize
86						};
87						let length_usize = if length < 0 {
88							0
89						} else {
90							length as usize
91						};
92
93						let substring = if start_idx >= chars_len {
94							// Start position is beyond string length
95							String::new()
96						} else {
97							let end_idx = (start_idx + length_usize).min(chars_len);
98							chars[start_idx..end_idx].iter().collect()
99						};
100
101						result_data.push(substring);
102					} else {
103						result_data.push(String::new());
104					}
105				}
106
107				let result_col_data = ColumnBuffer::Utf8 {
108					container: Utf8Container::new(result_data),
109					max_bytes: *max_bytes,
110				};
111
112				// Combine all three bitvecs
113				let mut combined_bv: Option<BitVec> = None;
114				for bv in [text_bv, start_bv, length_bv].into_iter().flatten() {
115					combined_bv = Some(match combined_bv {
116						Some(existing) => existing.and(bv),
117						None => bv.clone(),
118					});
119				}
120
121				let final_data = match combined_bv {
122					Some(bv) => ColumnBuffer::Option {
123						inner: Box::new(result_col_data),
124						bitvec: bv,
125					},
126					None => result_col_data,
127				};
128				Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), final_data)]))
129			}
130			// Handle cases where start/length are different integer types
131			(
132				ColumnBuffer::Utf8 {
133					container: text_container,
134					max_bytes,
135				},
136				start_d,
137				length_d,
138			) => {
139				let mut result_data = Vec::with_capacity(text_container.len());
140
141				for i in 0..row_count {
142					if text_container.is_defined(i) {
143						let original_str = text_container.get(i).unwrap();
144
145						// Extract start position from various integer types
146						let start_pos = match start_d {
147							ColumnBuffer::Int1(container) => {
148								container.get(i).map(|&v| v as i32).unwrap_or(0)
149							}
150							ColumnBuffer::Int2(container) => {
151								container.get(i).map(|&v| v as i32).unwrap_or(0)
152							}
153							ColumnBuffer::Int4(container) => {
154								container.get(i).copied().unwrap_or(0)
155							}
156							ColumnBuffer::Int8(container) => {
157								container.get(i).map(|&v| v as i32).unwrap_or(0)
158							}
159							_ => 0,
160						};
161
162						// Extract length from various integer types
163						let length = match length_d {
164							ColumnBuffer::Int1(container) => {
165								container.get(i).map(|&v| v as i32).unwrap_or(0)
166							}
167							ColumnBuffer::Int2(container) => {
168								container.get(i).map(|&v| v as i32).unwrap_or(0)
169							}
170							ColumnBuffer::Int4(container) => {
171								container.get(i).copied().unwrap_or(0)
172							}
173							ColumnBuffer::Int8(container) => {
174								container.get(i).map(|&v| v as i32).unwrap_or(0)
175							}
176							_ => 0,
177						};
178
179						// Get the substring with proper Unicode handling
180						let chars: Vec<char> = original_str.chars().collect();
181						let chars_len = chars.len();
182
183						// Convert negative start to positive index from end
184						let start_idx = if start_pos < 0 {
185							chars_len.saturating_sub((-start_pos) as usize)
186						} else {
187							start_pos as usize
188						};
189						let length_usize = if length < 0 {
190							0
191						} else {
192							length as usize
193						};
194
195						let substring = if start_idx >= chars_len {
196							// Start position is beyond string length
197							String::new()
198						} else {
199							let end_idx = (start_idx + length_usize).min(chars_len);
200							chars[start_idx..end_idx].iter().collect()
201						};
202
203						result_data.push(substring);
204					} else {
205						result_data.push(String::new());
206					}
207				}
208
209				let result_col_data = ColumnBuffer::Utf8 {
210					container: Utf8Container::new(result_data),
211					max_bytes: *max_bytes,
212				};
213
214				// Combine all three bitvecs
215				let mut combined_bv: Option<BitVec> = None;
216				for bv in [text_bv, start_bv, length_bv].into_iter().flatten() {
217					combined_bv = Some(match combined_bv {
218						Some(existing) => existing.and(bv),
219						None => bv.clone(),
220					});
221				}
222
223				let final_data = match combined_bv {
224					Some(bv) => ColumnBuffer::Option {
225						inner: Box::new(result_col_data),
226						bitvec: bv,
227					},
228					None => result_col_data,
229				};
230				Ok(Columns::new(vec![ColumnWithName::new(ctx.fragment.clone(), final_data)]))
231			}
232			(other, _, _) => Err(RoutineError::FunctionInvalidArgumentType {
233				function: ctx.fragment.clone(),
234				argument_index: 0,
235				expected: vec![Type::Utf8],
236				actual: other.get_type(),
237			}),
238		}
239	}
240}
241
242impl Function for TextSubstring {
243	fn kinds(&self) -> &[FunctionKind] {
244		&[FunctionKind::Scalar]
245	}
246}