1use crate::function::Function;
7use crate::helpers::{HasModule, NodeExt};
8use crate::raw::{Node, NodeKind};
9use crate::symbol::Symbol;
10use crate::types::{GenericSignature, ImplFunctionType, TypeRef};
11use crate::witness_table::ProtocolConformance;
12
/// A thunk symbol, classified by the kind of the node it was built from.
///
/// Every variant keeps the originating node, either directly in a `raw` field or
/// inside the wrapped thunk struct.
pub enum Thunk<'ctx> {
    /// Built from the reabstraction thunk node kinds, including the helper variants
    /// and the autodiff self-reordering reabstraction thunk.
    Reabstraction(ReabstractionThunk<'ctx>),
    /// Built from a protocol witness or protocol self-conformance witness node.
    ProtocolWitness(ProtocolWitnessThunk<'ctx>),
    /// Built from an autodiff subset-parameters or derivative-vtable thunk node.
    AutoDiff(AutoDiffThunk<'ctx>),
    /// Built from a dispatch, vtable, or distributed thunk node.
    Dispatch {
        /// The wrapped symbol; `Thunk::new` substitutes `Symbol::Other(raw)` when
        /// no inner node can be found.
        inner: Box<Symbol<'ctx>>,
        /// Which of the dispatch-style node kinds produced this value.
        kind: DispatchKind,
        /// The node this thunk was created from.
        raw: Node<'ctx>,
    },
    /// Built from a partial apply forwarder node.
    PartialApply {
        /// The forwarded symbol, if a child other than an attribute node exists.
        inner: Option<Box<Symbol<'ctx>>>,
        /// `true` when the node is the ObjC forwarder kind.
        is_objc: bool,
        /// The node this thunk was created from.
        raw: Node<'ctx>,
    },
    /// Any other node kind handed to the constructors.
    Other {
        /// Finer-grained classification of the node kind.
        kind: OtherThunkKind,
        /// The wrapped symbol, if one was found.
        inner: Option<Box<Symbol<'ctx>>>,
        /// The node this thunk was created from.
        raw: Node<'ctx>,
    },
}
40
41impl std::fmt::Debug for Thunk<'_> {
42 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
43 match self {
44 Self::Reabstraction(t) => f.debug_tuple("Reabstraction").field(t).finish(),
45 Self::ProtocolWitness(t) => f.debug_tuple("ProtocolWitness").field(t).finish(),
46 Self::AutoDiff(t) => f.debug_tuple("AutoDiff").field(t).finish(),
47 Self::Dispatch { inner, kind, .. } => f
48 .debug_struct("Dispatch")
49 .field("inner", inner)
50 .field("kind", kind)
51 .finish(),
52 Self::PartialApply { inner, is_objc, .. } => f
53 .debug_struct("PartialApply")
54 .field("inner", inner)
55 .field("is_objc", is_objc)
56 .finish(),
57 Self::Other { kind, inner, .. } => f
58 .debug_struct("Other")
59 .field("kind", kind)
60 .field("inner", inner)
61 .finish(),
62 }
63 }
64}
65
66impl<'ctx> Thunk<'ctx> {
67 pub fn new(raw: Node<'ctx>) -> Self {
69 match raw.kind() {
70 NodeKind::ReabstractionThunk
72 | NodeKind::ReabstractionThunkHelper
73 | NodeKind::ReabstractionThunkHelperWithSelf
74 | NodeKind::ReabstractionThunkHelperWithGlobalActor
75 | NodeKind::AutoDiffSelfReorderingReabstractionThunk => {
76 Self::Reabstraction(ReabstractionThunk::new(raw))
77 }
78
79 NodeKind::ProtocolWitness | NodeKind::ProtocolSelfConformanceWitness => {
81 Self::ProtocolWitness(ProtocolWitnessThunk::new(raw))
82 }
83
84 NodeKind::AutoDiffSubsetParametersThunk | NodeKind::AutoDiffDerivativeVTableThunk => {
86 Self::AutoDiff(AutoDiffThunk::new(raw))
87 }
88
89 NodeKind::DispatchThunk => Self::Dispatch {
91 inner: Box::new(
92 raw.child(0)
93 .map(Symbol::classify_node)
94 .unwrap_or(Symbol::Other(raw)),
95 ),
96 kind: DispatchKind::Protocol,
97 raw,
98 },
99 NodeKind::VTableThunk => Self::Dispatch {
100 inner: Box::new(
101 raw.child(0)
102 .map(Symbol::classify_node)
103 .unwrap_or(Symbol::Other(raw)),
104 ),
105 kind: DispatchKind::VTable,
106 raw,
107 },
108 NodeKind::DistributedThunk => {
109 let inner = raw.child(0).map(Symbol::classify_node).or_else(|| {
111 raw.descendants()
113 .find(|d| d.kind() == NodeKind::Function)
114 .map(Symbol::classify_node)
115 });
116 Self::Dispatch {
117 inner: Box::new(inner.unwrap_or(Symbol::Other(raw))),
118 kind: DispatchKind::Distributed,
119 raw,
120 }
121 }
122
123 NodeKind::PartialApplyForwarder => Self::PartialApply {
125 inner: find_inner_symbol(raw),
126 is_objc: false,
127 raw,
128 },
129 NodeKind::PartialApplyObjCForwarder => Self::PartialApply {
130 inner: find_inner_symbol(raw),
131 is_objc: true,
132 raw,
133 },
134
135 NodeKind::CurryThunk => Self::Other {
137 kind: OtherThunkKind::Curry,
138 inner: raw.child(0).map(|c| Box::new(Symbol::classify_node(c))),
139 raw,
140 },
141 NodeKind::KeyPathGetterThunkHelper => Self::Other {
142 kind: OtherThunkKind::KeyPathGetter,
143 inner: raw.child(0).map(|c| Box::new(Symbol::classify_node(c))),
144 raw,
145 },
146 NodeKind::KeyPathSetterThunkHelper => Self::Other {
147 kind: OtherThunkKind::KeyPathSetter,
148 inner: raw.child(0).map(|c| Box::new(Symbol::classify_node(c))),
149 raw,
150 },
151 NodeKind::KeyPathUnappliedMethodThunkHelper
152 | NodeKind::KeyPathAppliedMethodThunkHelper => Self::Other {
153 kind: OtherThunkKind::KeyPathMethod,
154 inner: raw.child(0).map(|c| Box::new(Symbol::classify_node(c))),
155 raw,
156 },
157 NodeKind::KeyPathEqualsThunkHelper => Self::Other {
158 kind: OtherThunkKind::KeyPathEquals,
159 inner: raw.child(0).map(|c| Box::new(Symbol::classify_node(c))),
160 raw,
161 },
162 NodeKind::KeyPathHashThunkHelper => Self::Other {
163 kind: OtherThunkKind::KeyPathHash,
164 inner: raw.child(0).map(|c| Box::new(Symbol::classify_node(c))),
165 raw,
166 },
167 NodeKind::BackDeploymentThunk => Self::Other {
168 kind: OtherThunkKind::BackDeployment,
169 inner: raw.child(0).map(|c| Box::new(Symbol::classify_node(c))),
170 raw,
171 },
172 NodeKind::BackDeploymentFallback => Self::Other {
173 kind: OtherThunkKind::BackDeploymentFallback,
174 inner: raw.child(0).map(|c| Box::new(Symbol::classify_node(c))),
175 raw,
176 },
177 NodeKind::MergedFunction => Self::Other {
178 kind: OtherThunkKind::Merged,
179 inner: find_inner_symbol_skip_metadata(raw),
180 raw,
181 },
182 NodeKind::InlinedGenericFunction => Self::Other {
183 kind: OtherThunkKind::InlinedGeneric,
184 inner: find_inner_symbol_skip_metadata(raw),
185 raw,
186 },
187 _ => Self::Other {
188 kind: OtherThunkKind::Unknown,
189 inner: raw.child(0).map(|c| Box::new(Symbol::classify_node(c))),
190 raw,
191 },
192 }
193 }
194
195 pub fn new_marker(raw: Node<'ctx>, inner: Symbol<'ctx>) -> Self {
200 match raw.kind() {
201 NodeKind::DistributedThunk => Self::Dispatch {
202 inner: Box::new(inner),
203 kind: DispatchKind::Distributed,
204 raw,
205 },
206 NodeKind::MergedFunction => Self::Other {
207 kind: OtherThunkKind::Merged,
208 inner: Some(Box::new(inner)),
209 raw,
210 },
211 NodeKind::InlinedGenericFunction => Self::Other {
212 kind: OtherThunkKind::InlinedGeneric,
213 inner: Some(Box::new(inner)),
214 raw,
215 },
216 _ => Self::Other {
217 kind: OtherThunkKind::Unknown,
218 inner: Some(Box::new(inner)),
219 raw,
220 },
221 }
222 }
223
224 pub fn raw(&self) -> Node<'ctx> {
226 match self {
227 Self::Reabstraction(t) => t.raw,
228 Self::ProtocolWitness(t) => t.raw,
229 Self::AutoDiff(t) => t.raw,
230 Self::Dispatch { raw, .. } => *raw,
231 Self::PartialApply { raw, .. } => *raw,
232 Self::Other { raw, .. } => *raw,
233 }
234 }
235
236 pub fn module(&self) -> Option<&'ctx str> {
238 match self {
239 Self::Reabstraction(t) => t.module(),
240 Self::ProtocolWitness(t) => t.module(),
241 Self::AutoDiff(t) => t.module(),
242 Self::Dispatch { raw, .. } => find_module_in_descendants(*raw),
243 Self::PartialApply { raw, .. } => find_module_in_descendants(*raw),
244 Self::Other { raw, .. } => find_module_in_descendants(*raw),
245 }
246 }
247
248 pub fn kind_name(&self) -> &'static str {
250 match self {
251 Self::Reabstraction(_) => "reabstraction thunk",
252 Self::ProtocolWitness(t) => {
253 if t.is_self_conformance {
254 "protocol self-conformance witness"
255 } else {
256 "protocol witness"
257 }
258 }
259 Self::AutoDiff(t) => t.kind.name(),
260 Self::Dispatch { kind, .. } => kind.name(),
261 Self::PartialApply { is_objc, .. } => {
262 if *is_objc {
263 "partial apply ObjC forwarder"
264 } else {
265 "partial apply forwarder"
266 }
267 }
268 Self::Other { kind, .. } => kind.name(),
269 }
270 }
271}
272
273impl std::fmt::Display for Thunk<'_> {
274 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
275 write!(f, "{}", self.raw())
276 }
277}
278
279fn find_module_in_descendants(node: Node<'_>) -> Option<&'_ str> {
280 for desc in node.descendants() {
281 if desc.kind() == NodeKind::Module {
282 return desc.text();
283 }
284 }
285 None
286}
287
288fn find_inner_symbol<'ctx>(raw: Node<'ctx>) -> Option<Box<Symbol<'ctx>>> {
290 for child in raw.children() {
291 match child.kind() {
293 NodeKind::ObjCAttribute | NodeKind::NonObjCAttribute | NodeKind::DynamicAttribute => {
294 continue;
295 }
296 _ => return Some(Box::new(Symbol::classify_node(child))),
297 }
298 }
299 None
300}
301
302fn find_inner_symbol_skip_metadata<'ctx>(raw: Node<'ctx>) -> Option<Box<Symbol<'ctx>>> {
304 for child in raw.children() {
305 match child.kind() {
307 NodeKind::SpecializationPassID | NodeKind::Number | NodeKind::Index => continue,
308 _ => return Some(Box::new(Symbol::classify_node(child))),
309 }
310 }
311 None
312}
313
/// Lightweight view over a reabstraction thunk node.
///
/// Only the node itself is stored; the accessors read its children on each call.
#[derive(Clone, Copy)]
pub struct ReabstractionThunk<'ctx> {
    /// The node this thunk was created from.
    raw: Node<'ctx>,
}
319
320impl<'ctx> ReabstractionThunk<'ctx> {
321 fn new(raw: Node<'ctx>) -> Self {
322 Self { raw }
323 }
324
325 pub fn target(&self) -> Option<TypeRef<'ctx>> {
327 self.raw.extract_type_ref()
328 }
329
330 pub fn source(&self) -> Option<TypeRef<'ctx>> {
332 let mut found_first = false;
333 for child in self.raw.children() {
334 if child.kind() == NodeKind::Type {
335 if found_first {
336 return Some(TypeRef::new(child.child(0).unwrap_or(child)));
337 }
338 found_first = true;
339 }
340 }
341 None
342 }
343
344 pub fn generic_signature(&self) -> Option<GenericSignature<'ctx>> {
346 self.raw
347 .child_of_kind(NodeKind::DependentGenericSignature)
348 .map(GenericSignature::new)
349 }
350
351 pub fn module(&self) -> Option<&'ctx str> {
353 find_module_in_descendants(self.raw)
354 }
355}
356
357impl std::fmt::Debug for ReabstractionThunk<'_> {
358 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
359 let mut s = f.debug_struct("ReabstractionThunk");
360 s.field("target", &self.target());
361 s.field("source", &self.source());
362 if let Some(sig) = self.generic_signature() {
363 s.field("generic_signature", &sig);
364 }
365 s.field("module", &self.module());
366 s.finish()
367 }
368}
369
/// Lightweight view over a protocol witness node.
#[derive(Clone, Copy)]
pub struct ProtocolWitnessThunk<'ctx> {
    /// The node this thunk was created from.
    raw: Node<'ctx>,
    /// `true` when the node kind is `ProtocolSelfConformanceWitness`.
    is_self_conformance: bool,
}
376
377impl<'ctx> ProtocolWitnessThunk<'ctx> {
378 fn new(raw: Node<'ctx>) -> Self {
379 Self {
380 raw,
381 is_self_conformance: raw.kind() == NodeKind::ProtocolSelfConformanceWitness,
382 }
383 }
384
385 pub fn conformance(&self) -> Option<ProtocolConformance<'ctx>> {
387 self.raw
388 .child_of_kind(NodeKind::ProtocolConformance)
389 .map(ProtocolConformance::new)
390 }
391
392 pub fn inner(&self) -> Option<Symbol<'ctx>> {
394 for child in self.raw.children() {
395 if child.kind() == NodeKind::Function {
396 return Some(Symbol::classify_node(child));
397 }
398 }
399 for child in self.raw.children() {
401 if child.kind() != NodeKind::ProtocolConformance {
402 return Some(Symbol::classify_node(child));
403 }
404 }
405 None
406 }
407
408 pub fn is_self_conformance(&self) -> bool {
410 self.is_self_conformance
411 }
412
413 pub fn module(&self) -> Option<&'ctx str> {
415 self.conformance()
416 .and_then(|c| c.module())
417 .or_else(|| find_module_in_descendants(self.raw))
418 }
419}
420
421impl std::fmt::Debug for ProtocolWitnessThunk<'_> {
422 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
423 f.debug_struct("ProtocolWitnessThunk")
424 .field("is_self_conformance", &self.is_self_conformance)
425 .field("conformance", &self.conformance())
426 .field("inner", &self.inner())
427 .field("module", &self.module())
428 .finish()
429 }
430}
431
/// Lightweight view over an autodiff thunk node.
#[derive(Clone, Copy)]
pub struct AutoDiffThunk<'ctx> {
    /// The node this thunk was created from.
    raw: Node<'ctx>,
    /// Thunk flavour derived from the kind of `raw` at construction time.
    kind: AutoDiffThunkKind,
}
438
439impl<'ctx> AutoDiffThunk<'ctx> {
440 fn new(raw: Node<'ctx>) -> Self {
441 let kind = match raw.kind() {
442 NodeKind::AutoDiffSubsetParametersThunk => AutoDiffThunkKind::SubsetParameters,
443 NodeKind::AutoDiffDerivativeVTableThunk => AutoDiffThunkKind::DerivativeVTable,
444 _ => AutoDiffThunkKind::Unknown,
445 };
446 Self { raw, kind }
447 }
448
449 pub fn kind(&self) -> AutoDiffThunkKind {
451 self.kind
452 }
453
454 pub fn function(&self) -> Option<Function<'ctx>> {
456 self.raw
457 .child_of_kind(NodeKind::Function)
458 .map(Function::new)
459 }
460
461 pub fn autodiff_function_kind(&self) -> Option<u64> {
463 self.raw
464 .child_of_kind(NodeKind::AutoDiffFunctionKind)
465 .and_then(|c| c.index())
466 }
467
468 pub fn function_type(&self) -> Option<ImplFunctionType<'ctx>> {
470 for child in self.raw.children() {
471 if let Some(inner) = child.unwrap_if_kind(NodeKind::Type)
472 && inner.kind() == NodeKind::ImplFunctionType
473 {
474 return Some(ImplFunctionType::new(inner));
475 }
476 }
477 None
478 }
479
480 pub fn parameter_indices(&self) -> Option<&'ctx str> {
482 let mut found_first = false;
483 for child in self.raw.children() {
484 if child.kind() == NodeKind::IndexSubset {
485 if found_first {
486 return child.text();
487 }
488 found_first = true;
489 }
490 }
491 None
492 }
493
494 pub fn result_indices(&self) -> Option<&'ctx str> {
496 for child in self.raw.children() {
497 if child.kind() == NodeKind::IndexSubset {
498 return child.text();
499 }
500 }
501 None
502 }
503
504 pub fn to_parameter_indices(&self) -> Option<&'ctx str> {
506 let mut count = 0;
507 for child in self.raw.children() {
508 if child.kind() == NodeKind::IndexSubset {
509 count += 1;
510 if count == 3 {
511 return child.text();
512 }
513 }
514 }
515 None
516 }
517
518 pub fn module(&self) -> Option<&'ctx str> {
520 self.function()
521 .and_then(|f| f.module())
522 .or_else(|| find_module_in_descendants(self.raw))
523 }
524}
525
526impl std::fmt::Debug for AutoDiffThunk<'_> {
527 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
528 let mut s = f.debug_struct("AutoDiffThunk");
529 s.field("kind", &self.kind);
530 s.field("autodiff_function_kind", &self.autodiff_function_kind());
531 s.field("function", &self.function());
532 s.field("function_type", &self.function_type());
533 s.field("parameter_indices", &self.parameter_indices());
534 s.field("result_indices", &self.result_indices());
535 if self.kind == AutoDiffThunkKind::SubsetParameters {
536 s.field("to_parameter_indices", &self.to_parameter_indices());
537 }
538 s.field("module", &self.module());
539 s.finish()
540 }
541}
542
/// Flavour of an autodiff thunk.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AutoDiffThunkKind {
    /// Autodiff subset parameters thunk.
    SubsetParameters,
    /// Autodiff derivative vtable thunk.
    DerivativeVTable,
    /// Fallback for an autodiff thunk of neither kind above.
    Unknown,
}
553
554impl AutoDiffThunkKind {
555 pub fn name(&self) -> &'static str {
557 match self {
558 Self::SubsetParameters => "autodiff subset parameters thunk",
559 Self::DerivativeVTable => "autodiff derivative vtable thunk",
560 Self::Unknown => "autodiff thunk",
561 }
562 }
563}
564
/// Flavour of a dispatch-style thunk.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum DispatchKind {
    /// Built from a `DispatchThunk` node.
    Protocol,
    /// Built from a `VTableThunk` node.
    VTable,
    /// Built from a `DistributedThunk` node.
    Distributed,
}
575
576impl DispatchKind {
577 pub fn name(&self) -> &'static str {
579 match self {
580 Self::Protocol => "dispatch thunk",
581 Self::VTable => "vtable thunk",
582 Self::Distributed => "distributed thunk",
583 }
584 }
585}
586
/// Classification of thunks that have no dedicated [`Thunk`] variant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OtherThunkKind {
    /// Curry thunk.
    Curry,
    /// Key path getter thunk helper.
    KeyPathGetter,
    /// Key path setter thunk helper.
    KeyPathSetter,
    /// Key path method thunk helper (applied or unapplied).
    KeyPathMethod,
    /// Key path equals thunk helper.
    KeyPathEquals,
    /// Key path hash thunk helper.
    KeyPathHash,
    /// Back deployment thunk.
    BackDeployment,
    /// Back deployment fallback.
    BackDeploymentFallback,
    /// Merged function.
    Merged,
    /// Inlined generic function.
    InlinedGeneric,
    /// Any node kind not covered by the variants above.
    Unknown,
}
613
614impl OtherThunkKind {
615 pub fn name(&self) -> &'static str {
617 match self {
618 Self::Curry => "curry thunk",
619 Self::KeyPathGetter => "keypath getter thunk",
620 Self::KeyPathSetter => "keypath setter thunk",
621 Self::KeyPathMethod => "keypath method thunk",
622 Self::KeyPathEquals => "keypath equals thunk",
623 Self::KeyPathHash => "keypath hash thunk",
624 Self::BackDeployment => "back deployment thunk",
625 Self::BackDeploymentFallback => "back deployment fallback",
626 Self::Merged => "merged function",
627 Self::InlinedGeneric => "inlined generic function",
628 Self::Unknown => "thunk",
629 }
630 }
631}
632
#[cfg(test)]
mod tests {
    use super::*;
    use crate::raw::Context;

    /// A protocol dispatch thunk wraps the constructor it dispatches to.
    #[test]
    fn test_dispatch_thunk() {
        let ctx = Context::new();
        let mangled = "$ss10SetAlgebraPyxqd__ncSTRd__7ElementQyd__ACRtzlufCTj";
        let symbol = Symbol::parse(&ctx, mangled).unwrap();
        assert!(symbol.is_thunk());

        let Symbol::Thunk(Thunk::Dispatch { inner, kind, .. }) = symbol else {
            panic!("Expected dispatch thunk");
        };
        assert_eq!(kind, DispatchKind::Protocol);
        assert!(inner.is_constructor());
    }

    /// A protocol witness exposes both its conformance and the witnessed function.
    #[test]
    fn test_protocol_witness() {
        let ctx = Context::new();
        let mangled = "_TTWC13call_protocol1CS_1PS_FS1_3foofT_Si";
        let symbol = Symbol::parse(&ctx, mangled).unwrap();
        assert!(symbol.is_thunk());

        let Symbol::Thunk(Thunk::ProtocolWitness(thunk)) = symbol else {
            panic!("Expected protocol witness thunk");
        };

        let conformance = thunk.conformance();
        assert!(conformance.is_some());
        let conformance = conformance.unwrap();
        assert!(conformance.conforming_type().is_some());
        assert!(conformance.protocol().is_some());
        assert_eq!(conformance.module(), Some("call_protocol"));

        let inner = thunk.inner();
        assert!(inner.is_some());
        let Some(Symbol::Function(func)) = inner else {
            panic!("Expected function as inner symbol");
        };
        assert_eq!(func.name(), Some("foo"));
        assert_eq!(func.containing_type(), Some("P"));
    }
}