Skip to main content

reifydb_routine/function/text/
substring.rs

// SPDX-License-Identifier: Apache-2.0
// Copyright (c) 2025 ReifyDB
3
4use reifydb_core::value::column::{Column, columns::Columns, data::ColumnData};
5use reifydb_type::{
6	util::bitvec::BitVec,
7	value::{container::utf8::Utf8Container, r#type::Type},
8};
9
10use crate::function::{Function, FunctionCapability, FunctionContext, FunctionInfo, error::FunctionError};
11
12pub struct TextSubstring {
13	info: FunctionInfo,
14}
15
16impl Default for TextSubstring {
17	fn default() -> Self {
18		Self::new()
19	}
20}
21
22impl TextSubstring {
23	pub fn new() -> Self {
24		Self {
25			info: FunctionInfo::new("text::substring"),
26		}
27	}
28}
29
30impl Function for TextSubstring {
31	fn info(&self) -> &FunctionInfo {
32		&self.info
33	}
34
35	fn capabilities(&self) -> &[FunctionCapability] {
36		&[FunctionCapability::Scalar]
37	}
38
39	fn return_type(&self, _input_types: &[Type]) -> Type {
40		Type::Utf8
41	}
42
43	fn execute(&self, ctx: &FunctionContext, args: &Columns) -> Result<Columns, FunctionError> {
44		// Validate exactly 3 arguments
45		if args.len() != 3 {
46			return Err(FunctionError::ArityMismatch {
47				function: ctx.fragment.clone(),
48				expected: 3,
49				actual: args.len(),
50			});
51		}
52
53		let text_col = &args[0];
54		let start_col = &args[1];
55		let length_col = &args[2];
56
57		let (text_data, text_bv) = text_col.data().unwrap_option();
58		let (start_data, start_bv) = start_col.data().unwrap_option();
59		let (length_data, length_bv) = length_col.data().unwrap_option();
60		let row_count = text_data.len();
61
62		match (text_data, start_data, length_data) {
63			(
64				ColumnData::Utf8 {
65					container: text_container,
66					max_bytes,
67				},
68				ColumnData::Int4(start_container),
69				ColumnData::Int4(length_container),
70			) => {
71				let mut result_data = Vec::with_capacity(text_container.data().len());
72
73				for i in 0..row_count {
74					if text_container.is_defined(i)
75						&& start_container.is_defined(i) && length_container.is_defined(i)
76					{
77						let original_str = &text_container[i];
78						let start_pos = start_container.get(i).copied().unwrap_or(0);
79						let length = length_container.get(i).copied().unwrap_or(0);
80
81						// Get the substring with proper Unicode handling
82						let chars: Vec<char> = original_str.chars().collect();
83						let chars_len = chars.len();
84
85						// Convert negative start to positive index from end
86						let start_idx = if start_pos < 0 {
87							chars_len.saturating_sub((-start_pos) as usize)
88						} else {
89							start_pos as usize
90						};
91						let length_usize = if length < 0 {
92							0
93						} else {
94							length as usize
95						};
96
97						let substring = if start_idx >= chars_len {
98							// Start position is beyond string length
99							String::new()
100						} else {
101							let end_idx = (start_idx + length_usize).min(chars_len);
102							chars[start_idx..end_idx].iter().collect()
103						};
104
105						result_data.push(substring);
106					} else {
107						result_data.push(String::new());
108					}
109				}
110
111				let result_col_data = ColumnData::Utf8 {
112					container: Utf8Container::new(result_data),
113					max_bytes: *max_bytes,
114				};
115
116				// Combine all three bitvecs
117				let mut combined_bv: Option<BitVec> = None;
118				for bv in [text_bv, start_bv, length_bv].into_iter().flatten() {
119					combined_bv = Some(match combined_bv {
120						Some(existing) => existing.and(bv),
121						None => bv.clone(),
122					});
123				}
124
125				let final_data = match combined_bv {
126					Some(bv) => ColumnData::Option {
127						inner: Box::new(result_col_data),
128						bitvec: bv,
129					},
130					None => result_col_data,
131				};
132				Ok(Columns::new(vec![Column::new(ctx.fragment.clone(), final_data)]))
133			}
134			// Handle cases where start/length are different integer types
135			(
136				ColumnData::Utf8 {
137					container: text_container,
138					max_bytes,
139				},
140				start_d,
141				length_d,
142			) => {
143				let mut result_data = Vec::with_capacity(text_container.data().len());
144
145				for i in 0..row_count {
146					if text_container.is_defined(i) {
147						let original_str = &text_container[i];
148
149						// Extract start position from various integer types
150						let start_pos = match start_d {
151							ColumnData::Int1(container) => {
152								container.get(i).map(|&v| v as i32).unwrap_or(0)
153							}
154							ColumnData::Int2(container) => {
155								container.get(i).map(|&v| v as i32).unwrap_or(0)
156							}
157							ColumnData::Int4(container) => {
158								container.get(i).copied().unwrap_or(0)
159							}
160							ColumnData::Int8(container) => {
161								container.get(i).map(|&v| v as i32).unwrap_or(0)
162							}
163							_ => 0,
164						};
165
166						// Extract length from various integer types
167						let length = match length_d {
168							ColumnData::Int1(container) => {
169								container.get(i).map(|&v| v as i32).unwrap_or(0)
170							}
171							ColumnData::Int2(container) => {
172								container.get(i).map(|&v| v as i32).unwrap_or(0)
173							}
174							ColumnData::Int4(container) => {
175								container.get(i).copied().unwrap_or(0)
176							}
177							ColumnData::Int8(container) => {
178								container.get(i).map(|&v| v as i32).unwrap_or(0)
179							}
180							_ => 0,
181						};
182
183						// Get the substring with proper Unicode handling
184						let chars: Vec<char> = original_str.chars().collect();
185						let chars_len = chars.len();
186
187						// Convert negative start to positive index from end
188						let start_idx = if start_pos < 0 {
189							chars_len.saturating_sub((-start_pos) as usize)
190						} else {
191							start_pos as usize
192						};
193						let length_usize = if length < 0 {
194							0
195						} else {
196							length as usize
197						};
198
199						let substring = if start_idx >= chars_len {
200							// Start position is beyond string length
201							String::new()
202						} else {
203							let end_idx = (start_idx + length_usize).min(chars_len);
204							chars[start_idx..end_idx].iter().collect()
205						};
206
207						result_data.push(substring);
208					} else {
209						result_data.push(String::new());
210					}
211				}
212
213				let result_col_data = ColumnData::Utf8 {
214					container: Utf8Container::new(result_data),
215					max_bytes: *max_bytes,
216				};
217
218				// Combine all three bitvecs
219				let mut combined_bv: Option<BitVec> = None;
220				for bv in [text_bv, start_bv, length_bv].into_iter().flatten() {
221					combined_bv = Some(match combined_bv {
222						Some(existing) => existing.and(bv),
223						None => bv.clone(),
224					});
225				}
226
227				let final_data = match combined_bv {
228					Some(bv) => ColumnData::Option {
229						inner: Box::new(result_col_data),
230						bitvec: bv,
231					},
232					None => result_col_data,
233				};
234				Ok(Columns::new(vec![Column::new(ctx.fragment.clone(), final_data)]))
235			}
236			(other, _, _) => Err(FunctionError::InvalidArgumentType {
237				function: ctx.fragment.clone(),
238				argument_index: 0,
239				expected: vec![Type::Utf8],
240				actual: other.get_type(),
241			}),
242		}
243	}
244}