local assign = require(script.Parent.assign)
local ComponentLifecyclePhase = require(script.Parent.ComponentLifecyclePhase)
local Type = require(script.Parent.Type)
local Symbol = require(script.Parent.Symbol)
local invalidSetStateMessages = require(script.Parent.invalidSetStateMessages)
local internalAssert = require(script.Parent.internalAssert)
local config = require(script.Parent.GlobalConfig).get()
-- Upper bound on how many passes the repeat loop in `Component:__update` may
-- run during a single call before it raises `tooManyUpdatesMessage`.
local MAX_PENDING_UPDATES = 100
-- Key (created through `Symbol.named`) under which every component instance
-- stores its internal bookkeeping table; see `Component:__mount`.
local InternalData = Symbol.named("InternalData")
-- Error template raised by the default `Component:render`; `%q` receives the
-- stringified component class.
local componentMissingRenderMessage = [[
The component %q is missing the `render` method.
`render` must be defined when creating a Roact component!]]
-- Error template raised by `Component:__update` once more than
-- MAX_PENDING_UPDATES passes have run; `%q` receives the stringified
-- component class.
local tooManyUpdatesMessage = [[
The component %q has reached the setState update recursion limit.
When using `setState` in `didUpdate`, make sure that it won't repeat infinitely!]]
-- Metatable applied to component classes. Its only metamethod makes
-- `tostring(class)` yield the class's `__componentName` field.
local componentClassMetatable = {
	__tostring = function(class)
		return class.__componentName
	end,
}
-- Base component class table. It is tagged as a stateful component class,
-- named "Component", and points `__index` at itself so that instances using
-- it as a metatable resolve missing fields on the class.
local Component = setmetatable({}, componentClassMetatable)
Component.__componentName = "Component"
Component.__index = Component
Component[Type] = Type.StatefulComponentClass
-- Creates a new component class called `name`.
--
-- Every field of `self` except "extend" is copied onto the new class, so the
-- returned table has no `extend` field of its own. The new class is tagged as
-- a stateful component class, uses itself as `__index`, and gets the shared
-- class metatable.
--
-- When `config.typeChecks` is enabled, asserts that `self` is a stateful
-- component class and that `name` is a string.
function Component:extend(name)
	if config.typeChecks then
		assert(Type.of(self) == Type.StatefulComponentClass, "Invalid `self` argument to `extend`.")
		assert(typeof(name) == "string", "Component class name must be a string")
	end

	local derived = {}
	for field, member in pairs(self) do
		if field ~= "extend" then
			derived[field] = member
		end
	end

	-- These overwrite the values copied from `self` in the loop above.
	derived.__componentName = name
	derived.__index = derived
	derived[Type] = Type.StatefulComponentClass

	return setmetatable(derived, componentClassMetatable)
end
-- Runs the component class's `getDerivedStateFromProps`, if it defines one.
--
-- Returns the table produced by `getDerivedStateFromProps(incomingProps,
-- incomingState)`, or nil when the class has no such function or the function
-- returned nil. When `config.typeChecks` is enabled, a non-nil result that is
-- not a table fails an assertion.
function Component:__getDerivedState(incomingProps, incomingState)
	if config.internalTypeChecks then
		internalAssert(Type.of(self) == Type.StatefulComponentInstance, "Invalid use of `__getDerivedState`")
	end

	local deriveState = self[InternalData].componentClass.getDerivedStateFromProps
	if deriveState == nil then
		return nil
	end

	-- Called as a plain function: no `self` is passed to it.
	local result = deriveState(incomingProps, incomingState)
	if result == nil then
		return nil
	end

	if config.typeChecks then
		assert(typeof(result) == "table", "getDerivedStateFromProps must return a table!")
	end

	return result
end
-- Applies or queues a partial state update for this component instance.
--
-- `mapState` is either a table of fields to merge into the state, or a
-- function called as `mapState(state, props)` that returns such a table.
-- The `state` handed to the function is the pending state when one is queued,
-- otherwise the current state. If the function returns nil the update is
-- abandoned.
--
-- What happens next depends on the instance's lifecycle phase:
--   * ShouldUpdate / WillUpdate / Render: raises the phase-specific message
--     from `invalidSetStateMessages` (error level 2).
--   * WillUnmount: returns without doing anything.
--   * Init: the merged state is written straight to `self.state`.
--   * DidMount / DidUpdate / ReconcileChildren: the merged state is stored in
--     `pendingState` for `__update` to consume later.
--   * Idle: `__update` is invoked immediately with the merged state.
--   * anything else: raises `invalidSetStateMessages.default` (level 2).
--
-- Also raises (level 2) when `mapState` is neither a function nor a table.
function Component:setState(mapState)
	if config.typeChecks then
		assert(Type.of(self) == Type.StatefulComponentInstance, "Invalid `self` argument to `setState`.")
	end

	local internalData = self[InternalData]
	local lifecyclePhase = internalData.lifecyclePhase

	if
		lifecyclePhase == ComponentLifecyclePhase.ShouldUpdate
		or lifecyclePhase == ComponentLifecyclePhase.WillUpdate
		or lifecyclePhase == ComponentLifecyclePhase.Render
	then
		local messageTemplate = invalidSetStateMessages[lifecyclePhase]
		local message = messageTemplate:format(tostring(internalData.componentClass))
		error(message, 2)
	elseif lifecyclePhase == ComponentLifecyclePhase.WillUnmount then
		-- Deliberately silent: no error and no state change while unmounting.
		return
	end

	local pendingState = internalData.pendingState

	local partialState
	if typeof(mapState) == "function" then
		partialState = mapState(pendingState or self.state, self.props)

		-- A nil result from the updater function cancels the update.
		if partialState == nil then
			return
		end
	elseif typeof(mapState) == "table" then
		partialState = mapState
	else
		error("Invalid argument to setState, expected function or table", 2)
	end

	local newState
	if pendingState ~= nil then
		-- The queued table itself is passed as assign's target, so an already
		-- pending update is extended rather than replaced.
		newState = assign(pendingState, partialState)
	else
		-- A fresh table is used as the target, so `self.state` is not the
		-- table being written to.
		newState = assign({}, self.state, partialState)
	end

	if lifecyclePhase == ComponentLifecyclePhase.Init then
		-- No `__update` call here: the new state is installed directly.
		local derivedState = self:__getDerivedState(self.props, newState)
		self.state = assign(newState, derivedState)
	elseif
		lifecyclePhase == ComponentLifecyclePhase.DidMount
		or lifecyclePhase == ComponentLifecyclePhase.DidUpdate
		or lifecyclePhase == ComponentLifecyclePhase.ReconcileChildren
	then
		-- Deferred: `__mount` / `__update` pick up `pendingState` afterwards.
		local derivedState = self:__getDerivedState(self.props, newState)
		internalData.pendingState = assign(newState, derivedState)
	elseif lifecyclePhase == ComponentLifecyclePhase.Idle then
		self:__update(nil, newState)
	else
		local messageTemplate = invalidSetStateMessages.default
		local message = messageTemplate:format(tostring(internalData.componentClass))
		error(message, 2)
	end
end
-- Returns the `source` field of the element currently held by this
-- instance's virtual node.
-- NOTE(review): presumably nil unless element tracebacks are enabled; confirm
-- where `source` is populated.
function Component:getElementTraceback()
	local virtualNode = self[InternalData].virtualNode
	local element = virtualNode.currentElement
	return element.source
end
-- Default `render`, reached only when a component class does not supply its
-- own. Always raises `componentMissingRenderMessage` (error level 0, so no
-- position information is prepended), naming the component class.
function Component:render()
	local className = tostring(self[InternalData].componentClass)
	error(componentMissingRenderMessage:format(className), 0)
end
-- Looks up `key` in the `context` table of this instance's virtual node and
-- returns the stored value (nil when the key is absent).
--
-- When `config.internalTypeChecks` is enabled, asserts that `self` is a
-- stateful component instance and that `key` is not nil.
function Component:__getContext(key)
	if config.internalTypeChecks then
		internalAssert(Type.of(self) == Type.StatefulComponentInstance, "Invalid use of `__getContext`")
		internalAssert(key ~= nil, "Context key cannot be nil")
	end

	return self[InternalData].virtualNode.context[key]
end
-- Stores `value` under `key` in the context of this instance's virtual node.
--
-- The node's existing context table is not written to: a new table is built
-- with `assign` from the existing context plus the new entry and installed as
-- `virtualNode.context`. The first time this runs for a node, the context it
-- had beforehand is remembered in `virtualNode.originalContext`.
function Component:__addContext(key, value)
	if config.internalTypeChecks then
		internalAssert(Type.of(self) == Type.StatefulComponentInstance, "Invalid use of `__addContext`")
	end

	local node = self[InternalData].virtualNode
	local currentContext = node.context

	-- Only the first call records the original; later calls leave it alone.
	if node.originalContext == nil then
		node.originalContext = currentContext
	end

	node.context = assign({}, currentContext, { [key] = value })
end
-- Runs the component class's `validateProps` function against `props`.
--
-- Does nothing when `config.propValidation` is off or the class defines no
-- `validateProps`. Raises when `validateProps` is not a function, and raises
-- (error level 0) when the validator's first return value is falsy; the
-- validator's second return value is used as the failure reason.
function Component:__validateProps(props)
	if not config.propValidation then
		return
	end

	local validate = self[InternalData].componentClass.validateProps
	if validate == nil then
		return
	end

	local validatorKind = typeof(validate)
	if validatorKind ~= "function" then
		local template = "validateProps must be a function, but it is a %s.\nCheck the definition of the component %q."
		error(template:format(validatorKind, self.__componentName))
	end

	local ok, reason = validate(props)
	if ok then
		return
	end

	-- Fall back to placeholders when the validator gave no reason or the
	-- element carries no traceback.
	reason = reason or "<Validator function did not supply a message>"
	local className = self.__componentName
	local reasonText = tostring(reason)
	local traceback = self:getElementTraceback() or "<enable element tracebacks>"

	error(("Property validation failed in %s: %s\n\n%s"):format(className, reasonText, traceback), 0)
end
-- Creates an instance of this component class for `virtualNode` and takes it
-- through its first render. Called on the class (`self` is the class).
--
-- Steps, in order:
--   1. Build the instance with its internal bookkeeping (phase Init) and
--      attach it to `virtualNode.instance`.
--   2. Resolve props (element props merged with `defaultProps` when the class
--      has them), validate them, and store them on the instance.
--   3. Give the instance its own copy of the node's `legacyContext`.
--   4. Compute initial state from `getDerivedStateFromProps`, run `init` if
--      defined, then fold derived state in again.
--   5. Write the instance's `_context` back to the node, render, and hand the
--      result to the reconciler.
--   6. Run `didMount` if defined, flush any state queued meanwhile through
--      `__update`, and finish in the Idle phase.
function Component:__mount(reconciler, virtualNode)
	if config.internalTypeChecks then
		internalAssert(Type.of(self) == Type.StatefulComponentClass, "Invalid use of `__mount`")
		internalAssert(Type.of(virtualNode) == Type.VirtualNode, "Expected arg #2 to be of type VirtualNode")
	end

	local element = virtualNode.currentElement
	local parent = virtualNode.hostParent

	-- Stored on the instance under the module-private InternalData key.
	local bookkeeping = {
		reconciler = reconciler,
		virtualNode = virtualNode,
		componentClass = self,
		lifecyclePhase = ComponentLifecyclePhase.Init,
		pendingState = nil,
	}

	local instance = setmetatable({
		[Type] = Type.StatefulComponentInstance,
		[InternalData] = bookkeeping,
	}, self)
	virtualNode.instance = instance

	local resolvedProps = element.props
	local defaults = self.defaultProps
	if defaults ~= nil then
		resolvedProps = assign({}, defaults, resolvedProps)
	end

	instance:__validateProps(resolvedProps)
	instance.props = resolvedProps

	instance._context = assign({}, virtualNode.legacyContext)
	instance.state = assign({}, instance:__getDerivedState(instance.props, {}))

	if instance.init ~= nil then
		instance:init(instance.props)
		-- Return value is discarded: the call is made for its effect on
		-- `instance.state`.
		assign(instance.state, instance:__getDerivedState(instance.props, instance.state))
	end

	-- Read back from the instance rather than reusing the table created
	-- above, so a `_context` reassigned during `init` is what gets stored.
	virtualNode.legacyContext = instance._context

	bookkeeping.lifecyclePhase = ComponentLifecyclePhase.Render
	local rendered = instance:render()

	bookkeeping.lifecyclePhase = ComponentLifecyclePhase.ReconcileChildren
	reconciler.updateVirtualNodeWithRenderResult(virtualNode, parent, rendered)

	if instance.didMount ~= nil then
		bookkeeping.lifecyclePhase = ComponentLifecyclePhase.DidMount
		instance:didMount()
	end

	-- `setState` calls made while reconciling or in `didMount` are queued in
	-- `pendingState`; `__update` consumes that queue on its own.
	if bookkeeping.pendingState ~= nil then
		instance:__update(nil, nil)
	end

	bookkeeping.lifecyclePhase = ComponentLifecyclePhase.Idle
end
-- Tears this instance down: runs `willUnmount` (if defined) with the phase
-- set to WillUnmount, then asks the reconciler to unmount every child of the
-- instance's virtual node.
function Component:__unmount()
	if config.internalTypeChecks then
		internalAssert(Type.of(self) == Type.StatefulComponentInstance, "Invalid use of `__unmount`")
	end

	local bookkeeping = self[InternalData]
	local node = bookkeeping.virtualNode
	local nodeReconciler = bookkeeping.reconciler

	if self.willUnmount ~= nil then
		-- The phase is not reset afterwards; `setState` ignores calls made in
		-- this phase.
		bookkeeping.lifecyclePhase = ComponentLifecyclePhase.WillUnmount
		self:willUnmount()
	end

	for _, child in pairs(node.children) do
		nodeReconciler.unmountVirtualNode(child)
	end
end
-- Applies a new element and/or new state to this instance, then keeps
-- resolving updates for as long as `internalData.pendingState` is set when a
-- pass finishes.
--
-- updatedElement: Element or nil. When given, its props (merged with the
--   class's `defaultProps` through `assign`, if the class has any) are
--   validated and used as the new props.
-- updatedState: table or nil. State to apply on the first pass.
--
-- Returns false as soon as `__resolveUpdate` returns false, otherwise true
-- once no pending state remains.
-- Raises `tooManyUpdatesMessage` (error level 3) after more than
-- MAX_PENDING_UPDATES passes.
function Component:__update(updatedElement, updatedState)
if config.internalTypeChecks then
internalAssert(Type.of(self) == Type.StatefulComponentInstance, "Invalid use of `__update`")
internalAssert(
Type.of(updatedElement) == Type.Element or updatedElement == nil,
"Expected arg #1 to be of type Element or nil"
)
internalAssert(
typeof(updatedState) == "table" or updatedState == nil,
"Expected arg #2 to be of type table or nil"
)
end
local internalData = self[InternalData]
local componentClass = internalData.componentClass
-- Without a new element the current props table is reused, which keeps the
-- `newProps ~= self.props` reference comparison below false.
local newProps = self.props
if updatedElement ~= nil then
newProps = updatedElement.props
if componentClass.defaultProps ~= nil then
newProps = assign({}, componentClass.defaultProps, newProps)
end
self:__validateProps(newProps)
end
local updateCount = 0
repeat
local finalState
local pendingState = nil
-- Take ownership of any queued state and clear the queue, so that state
-- queued during this pass is what the `until` condition sees.
if internalData.pendingState ~= nil then
pendingState = internalData.pendingState
internalData.pendingState = nil
end
-- Taken when there is explicit state to apply or the props table differs (by
-- reference) from the current one. `updatedState` is cleared below and
-- `__resolveUpdate` installs `newProps` as `self.props` when it proceeds, so
-- later passes fall through to the `else` branch.
if updatedState ~= nil or newProps ~= self.props then
if pendingState == nil then
finalState = updatedState or self.state
else
-- The queued table is the assign target here, with `updatedState` as the
-- source.
finalState = assign(pendingState, updatedState)
end
local derivedState = self:__getDerivedState(newProps, finalState)
if derivedState ~= nil then
-- A new table is used as the target, so `finalState` (possibly
-- `self.state`) is not written to.
finalState = assign({}, finalState, derivedState)
end
updatedState = nil
else
-- May be nil when nothing was queued; `__resolveUpdate` substitutes the
-- current state for a nil incoming state.
finalState = pendingState
end
if not self:__resolveUpdate(newProps, finalState) then
return false
end
updateCount = updateCount + 1
if updateCount > MAX_PENDING_UPDATES then
error(tooManyUpdatesMessage:format(tostring(internalData.componentClass)), 3)
end
until internalData.pendingState == nil
return true
end
-- Performs one update pass: consults `shouldUpdate`, runs `willUpdate`,
-- installs the new props and state, re-renders, hands the render result to
-- the reconciler, and runs `didUpdate`.
--
-- incomingProps / incomingState: the props and state to install; a nil value
--   means "keep the current one".
--
-- Returns false when `shouldUpdate` is defined and returns a falsy value (the
-- instance is left unchanged apart from its phase being set to Idle), true
-- otherwise.
function Component:__resolveUpdate(incomingProps, incomingState)
if config.internalTypeChecks then
internalAssert(Type.of(self) == Type.StatefulComponentInstance, "Invalid use of `__resolveUpdate`")
end
local internalData = self[InternalData]
local virtualNode = internalData.virtualNode
local reconciler = internalData.reconciler
-- Captured before anything is overwritten so `didUpdate` can receive the
-- previous values.
local oldProps = self.props
local oldState = self.state
if incomingProps == nil then
incomingProps = oldProps
end
if incomingState == nil then
incomingState = oldState
end
if self.shouldUpdate ~= nil then
internalData.lifecyclePhase = ComponentLifecyclePhase.ShouldUpdate
local continueWithUpdate = self:shouldUpdate(incomingProps, incomingState)
if not continueWithUpdate then
internalData.lifecyclePhase = ComponentLifecyclePhase.Idle
return false
end
end
if self.willUpdate ~= nil then
internalData.lifecyclePhase = ComponentLifecyclePhase.WillUpdate
self:willUpdate(incomingProps, incomingState)
end
internalData.lifecyclePhase = ComponentLifecyclePhase.Render
-- Props and state are swapped in only after `shouldUpdate` and `willUpdate`
-- have run, so both callbacks still see the old values on `self`.
self.props = incomingProps
self.state = incomingState
local renderResult = virtualNode.instance:render()
internalData.lifecyclePhase = ComponentLifecyclePhase.ReconcileChildren
reconciler.updateVirtualNodeWithRenderResult(virtualNode, virtualNode.hostParent, renderResult)
if self.didUpdate ~= nil then
internalData.lifecyclePhase = ComponentLifecyclePhase.DidUpdate
self:didUpdate(oldProps, oldState)
end
internalData.lifecyclePhase = ComponentLifecyclePhase.Idle
return true
end
return Component